diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
index 3cce8ed108e21d7267db74b3f5fa82ca08b2234f..b9b5986dcd08cd809386587c2ad9d58d6ee688c3 100644
--- a/a2c_sb3_cartpole.py
+++ b/a2c_sb3_cartpole.py
@@ -1,52 +1,99 @@
 import wandb
 import gymnasium as gym
+import numpy as np
 from stable_baselines3 import A2C
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
 from huggingface_sb3 import package_to_hub
+
+
+def a2c_sb3():
+    # Training env: Monitor records episode statistics, DummyVecEnv
+    # provides the vectorized interface SB3 expects.
+    env = gym.make("CartPole-v1", render_mode="rgb_array")
+    env = Monitor(env)
+    env = DummyVecEnv([lambda: env])
+
-env = gym.make("CartPole-v1", render_mode="rgb_array")
-env = Monitor(env)
-env = DummyVecEnv([lambda: env])
+    wandb.init(
+        entity="maximecerise-ecl",
+        project="cartpole-a2c",
+        sync_tensorboard=True,
+        monitor_gym=True,
+        save_code=True
+    )
+
-wandb.init(
-    entity="maximecerise-ecl",
-    project="cartpole-a2c",
-    sync_tensorboard=True,
-    monitor_gym=True,
-    save_code=True
-    )
+    model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
+    model.learn(total_timesteps=300000)
+
+    model.save("a2c_cartpole")
+    env.close()
+
+    # Evaluation env, built the same way as the training env.
+    eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
+    eval_env = Monitor(eval_env)
+    eval_env = DummyVecEnv([lambda: eval_env])
+
+    success_count = 0
+    num_episodes = 100
+    scores = []
+
+    for episode in range(num_episodes):
+        obs = eval_env.reset()
+        done = False
+        episode_reward = 0
+
+        while not done:
+            action, _ = model.predict(obs)
+            obs, reward, done, info = eval_env.step(action)
+            done = done[0]  # unwrap the boolean from the vectorized env
+            episode_reward += reward[0]  # unwrap the scalar reward as well
+
+        scores.append(episode_reward)
+        if episode_reward >= 200:
+            success_count += 1
+
-model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
-model.learn(total_timesteps=500000)
+        wandb.log({
+            "episode": episode,
+            "episode_reward": episode_reward,
+            "success_rate (%)": success_count / (episode + 1) * 100
+        })
+
-model.save("a2c_cartpole")
-env.close()
-eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
-eval_env = Monitor(eval_env)
-eval_env = DummyVecEnv([lambda: eval_env])
-video_folder = "./videos/"
-eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)
+    success_rate = success_count / num_episodes * 100
+    avg_score = np.mean(scores)
+
-obs = eval_env.reset()
-for _ in range(1000):
-    action, _ = model.predict(obs)
-    obs, _, _, _ = eval_env.step(action)
-eval_env.close()
+    wandb.log({
+        "final_success_rate (%)": success_rate,
+        "final_average_score": avg_score
+    })
+
+    print(f"A2C success rate over {num_episodes} episodes: {success_rate:.2f}%")
+    print(f"Average score: {avg_score:.2f}")
+
+    # Record a 1000-step rollout of the trained agent to ./videos/.
+    video_folder = "./videos/"
+    eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)
+
+    obs = eval_env.reset()
+    for _ in range(1000):
+        action, _ = model.predict(obs)
+        obs, _, _, _ = eval_env.step(action)
+
+    # package_to_hub evaluates the model on eval_env, so the env is
+    # closed only after the upload.
+    package_to_hub(
+        model=model,
+        model_name="a2c_cartpole",
+        model_architecture="A2C",
+        env_id="CartPole-v1",
+        eval_env=eval_env,
+        repo_id="MaximeCerise/a2c_cartpole",
+        commit_message="add a2c with evaluation"
+    )
+
+    eval_env.close()
-package_to_hub(
-    model=model,
-    model_name="a2c_cartpole",
-    model_architecture="A2C",
-    env_id="CartPole-v1",
-    eval_env=eval_env,
-    repo_id="MaximeCerise/a2c_cartpole",
-    commit_message="add a2c"
-)
-wandb.finish()
+    wandb.finish()
+
+
+if __name__ == "__main__":
+    a2c_sb3()
diff --git a/videos/rl-video-step-0-to-step-1000.mp4 b/videos/rl-video-step-0-to-step-1000.mp4
index 98f6318f63cb99401362e83ad589912a7b61629f..a63963d2e4454c2dbe3f2f1729016ffe980854c2 100644
Binary files a/videos/rl-video-step-0-to-step-1000.mp4 and b/videos/rl-video-step-0-to-step-1000.mp4 differ