Skip to content
Snippets Groups Projects
Commit b4e270ab authored by MaximeCerise's avatar MaximeCerise
Browse files

ok

parent 3b9fe874
No related branches found
No related tags found
No related merge requests found
import wandb
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from huggingface_sb3 import package_to_hub
def a2c_sb3():
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = Monitor(env)
env = DummyVecEnv([lambda: env])
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = Monitor(env)
env = DummyVecEnv([lambda: env])
wandb.init(
entity="maximecerise-ecl",
project="cartpole-a2c",
sync_tensorboard=True,
monitor_gym=True,
save_code=True
)
wandb.init(
entity="maximecerise-ecl",
project="cartpole-a2c",
sync_tensorboard=True,
monitor_gym=True,
save_code=True
)
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
model.learn(total_timesteps=300000)
model.save("a2c_cartpole")
env.close()
eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
eval_env = Monitor(eval_env)
eval_env = DummyVecEnv([lambda: eval_env])
success_count = 0
num_episodes = 100
scores = []
for episode in range(num_episodes):
obs = eval_env.reset()
done = False
episode_reward = 0
while not done:
action, _ = model.predict(obs)
obs, reward, done, info = eval_env.step(action)
done = done[0] # Extraction de la valeur booléenne
episode_reward += reward
scores.append(episode_reward)
if episode_reward >= 200:
success_count += 1
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
model.learn(total_timesteps=500000)
wandb.log({
"episode": episode,
"episode_reward": episode_reward,
"success_rate (%)": success_count / (episode + 1) * 100
})
model.save("a2c_cartpole")
env.close()
eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
eval_env = Monitor(eval_env)
eval_env = DummyVecEnv([lambda: eval_env])
video_folder = "./videos/"
eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)
success_rate = success_count / num_episodes * 100
avg_score = np.mean(scores)
obs = eval_env.reset()
for _ in range(1000):
action, _ = model.predict(obs)
obs, _, _, _ = eval_env.step(action)
eval_env.close()
wandb.log({
"final_success_rate (%)": success_rate,
"final_average_score": avg_score
})
print(f"Taux de succès du modèle A2C sur {num_episodes} épisodes : {success_rate:.2f}%")
print(f"Score moyen : {avg_score:.2f}")
video_folder = "./videos/"
eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)
obs = eval_env.reset()
for _ in range(1000):
action, _ = model.predict(obs)
obs, _, _, _ = eval_env.step(action)
eval_env.close()
package_to_hub(
model=model,
model_name="a2c_cartpole",
model_architecture="A2C",
env_id="CartPole-v1",
eval_env=eval_env,
repo_id="MaximeCerise/a2c_cartpole",
commit_message="add a2c with evaluation"
)
package_to_hub(
model=model,
model_name="a2c_cartpole",
model_architecture="A2C",
env_id="CartPole-v1",
eval_env=eval_env,
repo_id="MaximeCerise/a2c_cartpole",
commit_message="add a2c"
)
wandb.finish()
wandb.finish()
if __name__ == "__main__":
a2c_sb3()
\ No newline at end of file
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment