Skip to content
Snippets Groups Projects
Commit 8b00c644 authored by Brussart Paul-emile's avatar Brussart Paul-emile
Browse files

Adding a2c_sb3_cartpole.py using stable_baselines3

parent 69ad3981
No related branches found
No related tags found
No related merge requests found
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
# Create the CartPole environment
env = gym.make('CartPole-v1')
# Wrap the environment in a DummyVecEnv to handle multiple environments
env = DummyVecEnv([lambda: env])
# Initialize the A2C model
model = A2C('MlpPolicy', env, verbose=1)
# Train the model for 1000 steps
model.learn(total_timesteps=1000)
#Saving the model
model.save("a2c_sb3_cartpole")
# Test the trained model
obs = env.reset()
for i in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
\ No newline at end of file
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment