"""Train an A2C agent on CartPole-v1 with Stable-Baselines3 and replay it.

Running this script:

1. creates the ``CartPole-v1`` environment,
2. trains an ``A2C`` model with the ``"MlpPolicy"`` policy for 10000 timesteps,
3. saves the model under the path ``a2c_sb3_cartpole``,
4. steps the trained policy for 1000 steps, rendering each step.
"""

import gym

from stable_baselines3 import A2C

env = gym.make("CartPole-v1")

model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
model.save("a2c_sb3_cartpole")

# Use the environment held by the model rather than the raw ``env``, so the
# observations passed to ``predict`` come from the same env the model trained on.
vec_env = model.get_env()
try:
    obs = vec_env.reset()
    for _ in range(1000):
        # deterministic=True requests the policy's deterministic action
        # instead of sampling from the action distribution.
        action, _state = model.predict(obs, deterministic=True)
        # NOTE(review): ``done`` is not checked here; this relies on the
        # vectorized env resetting finished episodes itself -- confirm against
        # the installed Stable-Baselines3 version.
        obs, reward, done, info = vec_env.step(action)
        vec_env.render()
finally:
    # Always release the env (and any render window), even if stepping or
    # rendering raises part-way through the replay loop.
    vec_env.close()