Skip to content
Snippets Groups Projects
Commit 25c3cb07 authored by Brussart Paul-emile
Browse files

Adding a2c_sb3_cartpole.py using wandb to test the model

parent 8b00c644
No related branches found
No related tags found
No related merge requests found
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
# Experiment settings. This mapping is handed to wandb.init as the run's
# config and is also read when building the environment and when training.
config = dict(
    policy_type="MlpPolicy",
    total_timesteps=20_000,
    env_name="CartPole-v1",
)
# Open the Weights & Biases run for this experiment. The handle is kept in
# `run` because later statements read `run.id` and call `run.finish()`.
run = wandb.init(
    save_code=True,
    monitor_gym=True,
    sync_tensorboard=True,
    config=config,
    project="cartpole",
)
# Create the CartPole environment through a factory rather than eagerly:
# DummyVecEnv takes a list of callables, and an eagerly built, unmonitored
# environment would only be shadowed and left unclosed.
def make_env():
    """Build one monitored instance of the configured environment.

    The environment id is read from ``config["env_name"]``.

    Returns:
        The ``gym`` environment wrapped in ``Monitor``.
    """
    base_env = gym.make(config["env_name"])
    return Monitor(base_env)  # record stats such as returns
# Wrap the environment factory in a DummyVecEnv exactly once. The result is
# already a vectorised environment, so it must not be passed through a second
# DummyVecEnv (DummyVecEnv expects callables that return plain environments).
env = DummyVecEnv([make_env])
# Initialize the A2C model. The policy name is read from `config` so the value
# recorded with the W&B run is the one actually used. `tensorboard_log` makes
# SB3 write its training metrics to disk, which is what wandb's
# `sync_tensorboard=True` uploads.
model = A2C(
    config["policy_type"],
    env,
    verbose=1,
    tensorboard_log=f"runs/{run.id}",
)

# Train the model once, for the configured number of timesteps, with the W&B
# callback attached for the whole run (gradients logged every 100 steps, model
# files saved under models/<run id>).
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)

# Saving the model
model.save("a2c_sb3_cartpole")
# Test the trained model: run the policy for 1000 steps and render each one.
# NOTE(review): no manual reset on `dones` here; SB3 vectorised environments
# are documented to reset finished episodes automatically. Confirm for the
# installed version.
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

# Persist the trained model, then release the environment (and any render
# window) before closing the W&B run.
model.save("a2c_sb3_cartpole_model")
env.close()
run.finish()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment