Skip to content
Snippets Groups Projects
Select Git revision
  • e5b01982650b55893e5009fb4834cc0ec6545db9
  • master default protected
2 results

selection.py

Blame
  • Forked from Vuillemot Romain / INF-TC1
    Source project has a limited visibility.
    a2c_sb3_cartpole.py 606 B
    import gymnasium as gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    
    # Train an A2C agent on CartPole-v1, save it, then roll the trained policy
    # out for a fixed number of steps while rendering.

    # Create and wrap the environment (vectorized wrapper around one env).
    env_id = "CartPole-v1"
    env = make_vec_env(env_id, n_envs=1)

    # Everything after env creation runs under try/finally so the environment
    # is closed even if training or evaluation raises or is interrupted
    # (e.g. Ctrl-C during the rollout).
    try:
        # Initialize the A2C agent
        model = A2C('MlpPolicy', env, verbose=1)

        # Train the agent
        model.learn(total_timesteps=10000)

        # Save the trained model
        model.save("a2c_sb3_cartpole")

        # Evaluate the trained agent
        obs = env.reset()
        for _ in range(1000):
            # deterministic=True: evaluate the learned policy itself rather
            # than a fresh stochastic sample of it at every step.
            action, _states = model.predict(obs, deterministic=True)
            # `dones` is not inspected here; the loop simply keeps stepping.
            # NOTE(review): presumably the vectorized env resets finished
            # episodes on its own -- confirm against the SB3 VecEnv docs.
            obs, rewards, dones, info = env.step(action)
            # NOTE(review): render() is called without a mode and its return
            # value is discarded; whether this actually opens a window
            # depends on the render_mode the env was built with -- confirm.
            env.render()
    finally:
        # Close the environment
        env.close()