diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
index 2ec8cf2f9e908f5632d15140690b5a7590ea03f0..efbce452f8e1a8e82f6a9e699d81da17135ef4a1 100644
--- a/a2c_sb3_cartpole.py
+++ b/a2c_sb3_cartpole.py
@@ -1,25 +1,50 @@
 import gym
 from stable_baselines3 import A2C
+from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import DummyVecEnv
+import wandb
+from wandb.integration.sb3 import WandbCallback
+
+
+config = {
+    "policy_type": "MlpPolicy",
+    "total_timesteps": 20000,
+    "env_name": "CartPole-v1",
+}
+
+run = wandb.init(
+    project="cartpole",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+

 # Create the CartPole environment
-env = gym.make('CartPole-v1')
+def make_env():
+    env = gym.make(config["env_name"])
+    env = Monitor(env)  # record stats such as returns
+    return env
+
+env = DummyVecEnv([make_env])

-# Wrap the environment in a DummyVecEnv to handle multiple environments
-env = DummyVecEnv([lambda: env])

 # Initialize the A2C model
 model = A2C('MlpPolicy', env, verbose=1)

-# Train the model for 1000 steps
-model.learn(total_timesteps=1000)
+# Train the model for 20 000 steps
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=100,
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    )
+)

 #Saving the model
-model.save("a2c_sb3_cartpole")
-
-# Test the trained model
-obs = env.reset()
-for i in range(1000):
-    action, _states = model.predict(obs)
-    obs, rewards, dones, info = env.step(action)
-    env.render()
\ No newline at end of file
+model.save("a2c_sb3_cartpole_model")
+
+run.finish()
+
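
The change above removes the original test/render loop, so nothing in the new script exercises the saved policy. Below is a minimal evaluation sketch, not part of the diff: it assumes the "a2c_sb3_cartpole_model.zip" file written by model.save() above and the same pre-0.26 gym step API that the removed test loop used.

# Hedged follow-up sketch (separate script, not part of the diff above):
# reload the saved A2C policy and run it for a few episodes.
import gym
from stable_baselines3 import A2C

env = gym.make("CartPole-v1")
model = A2C.load("a2c_sb3_cartpole_model")  # SB3 adds the .zip suffix when saving

obs = env.reset()
episode_return = 0.0
for _ in range(1000):
    # Greedy action from the trained policy
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    episode_return += reward
    if done:
        print("episode return:", episode_return)
        obs = env.reset()
        episode_return = 0.0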