Skip to content
Snippets Groups Projects
Commit 15d811ac authored by MaximeCerise's avatar MaximeCerise
Browse files

a2c_sb3_cartpole.py

parent 07497c85
Branches
No related tags found
No related merge requests found
import wandb
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from huggingface_sb3 import package_to_hub
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = Monitor(env)
env = DummyVecEnv([lambda: env])
wandb.init(
entity="maximecerise-ecl",
project="cartpole-a2c_",
sync_tensorboard=True,
monitor_gym=True,
save_code=True
)
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=5000)
model.save("a2c_cartpole")
env.close()
eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
eval_env = Monitor(eval_env)
eval_env = DummyVecEnv([lambda: eval_env])
video_folder = "./videos/"
eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)
obs = eval_env.reset()
for _ in range(1000):
action, _ = model.predict(obs)
obs, _, _, _ = eval_env.step(action)
eval_env.close()
package_to_hub(
model=model,
model_name="a2c_cartpole",
model_architecture="A2C",
env_id="CartPole-v1",
eval_env=eval_env,
repo_id="MaximeCerise/a2c_cartpole",
commit_message="add a2c"
)
wandb.finish()
gymnasium gymnasium
gymnasium[classic-control] #gymnasium[classic-control]
torch torch
pandas pandas
numpy==1.26 numpy==1.26
matplotlib matplotlib
stable_baselines3
huggingface_sb3
#stable-baselines3[extra]
moviepy
wandb
\ No newline at end of file
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment