diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py new file mode 100644 index 0000000000000000000000000000000000000000..50f784709d629804612a5849af456d3b1e5d60f3 --- /dev/null +++ b/a2c_sb3_cartpole.py @@ -0,0 +1,52 @@ +import wandb +import gymnasium as gym +from stable_baselines3 import A2C +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder +from huggingface_sb3 import package_to_hub + + +env = gym.make("CartPole-v1", render_mode="rgb_array") +env = Monitor(env) +env = DummyVecEnv([lambda: env]) + + +wandb.init( + entity="maximecerise-ecl", + project="cartpole-a2c_", + sync_tensorboard=True, + monitor_gym=True, + save_code=True + ) + + +model = A2C("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=5000) + + +model.save("a2c_cartpole") +env.close() + +eval_env = gym.make("CartPole-v1", render_mode="rgb_array") +eval_env = Monitor(eval_env) +eval_env = DummyVecEnv([lambda: eval_env]) +video_folder = "./videos/" +eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000) + +obs = eval_env.reset() +for _ in range(1000): + action, _ = model.predict(obs) + obs, _, _, _ = eval_env.step(action) + +eval_env.close() + +package_to_hub( + model=model, + model_name="a2c_cartpole", + model_architecture="A2C", + env_id="CartPole-v1", + eval_env=eval_env, + repo_id="MaximeCerise/a2c_cartpole", + commit_message="add a2c" +) +wandb.finish() diff --git a/requirements.txt b/requirements.txt index c00a057237e746f1270c148a1b24297de437b0ac..e3fe496158151900fa87537ad0b7bb2b4c79f7bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,11 @@ gymnasium -gymnasium[classic-control] +#gymnasium[classic-control] torch - pandas numpy==1.26 -matplotlib \ No newline at end of file +matplotlib +stable_baselines3 +huggingface_sb3 +#stable-baselines3[extra] +moviepy +wandb \ No newline at end of file diff --git a/videos/rl-video-step-0-to-step-1000.mp4 b/videos/rl-video-step-0-to-step-1000.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..93ec3d2d9597f8803bbcc193c167c8779a59dea8 Binary files /dev/null and b/videos/rl-video-step-0-to-step-1000.mp4 differ