Skip to content
Snippets Groups Projects
Select Git revision
  • b5a9c00f66700d637b348522b0a1ec6a1f010686
  • master default protected
2 results

bfs.py

Blame
  • Forked from Vuillemot Romain / INF-TC1
    Source project has a limited visibility.
    a2c_sb3_cartpole.py 606 B
    import gymnasium as gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    
    # Train an A2C agent on CartPole-v1 with Stable-Baselines3, save it to
    # disk, then roll the trained policy out for 1000 steps while rendering.

    # Create and wrap the environment (a vectorized env with a single copy)
    env_id = "CartPole-v1"
    env = make_vec_env(env_id, n_envs=1)

    # Everything after creation runs under try/finally so the environment
    # (and any render window) is released even if training, saving or the
    # rollout raises.
    try:
        # Initialize the A2C agent
        model = A2C('MlpPolicy', env, verbose=1)

        # Train the agent
        model.learn(total_timesteps=10000)

        # Save the trained model
        model.save("a2c_sb3_cartpole")

        # Evaluate the trained agent
        obs = env.reset()
        for _ in range(1000):
            action, _states = model.predict(obs)
            # No manual reset on `dones`: SB3 vectorized environments reset a
            # finished sub-environment automatically.
            obs, rewards, dones, info = env.step(action)
            # With gymnasium-based SB3, make_vec_env builds the sub-envs with
            # render_mode="rgb_array", so a bare render() only returns the
            # frame; asking for "human" is what actually displays it (this is
            # the form used in the SB3 A2C example).
            env.render("human")
    finally:
        # Close the environment
        env.close()