a2c_sb3_cartpole.py

    import gymnasium as gym
    import numpy as np
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3 import A2C
    from huggingface_sb3 import push_to_hub
    from huggingface_hub import login
    
    
    
    print(f"{gym.__version__=}")
    
    
    # Create the training environment and an A2C agent with an MLP policy
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    model = A2C("MlpPolicy", env, verbose=1)
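
    # A hedged sketch (not in the original script): A2C's main hyperparameters
    # are exposed as keyword arguments; the values below are SB3's documented
    # defaults, so this call would be equivalent to the one above.
    # model = A2C("MlpPolicy", env, learning_rate=7e-4, n_steps=5, gamma=0.99, verbose=1)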
    
    def evaluate(model, num_episodes=100, deterministic=True):
        """Return the mean total reward over `num_episodes` episodes."""
        vec_env = model.get_env()
        all_episode_rewards = []
        for i in range(num_episodes):
            episode_rewards = []
            done = False
            obs = vec_env.reset()
            while not done:
                # _states are only useful when using LSTM policies
                action, _states = model.predict(obs, deterministic=deterministic)
                # here, action, rewards and dones are arrays
                # also note that step only returns a 4-tuple, as the env returned
                # by model.get_env() is an SB3 VecEnv wrapping the original env
                obs, reward, done, info = vec_env.step(action)
                episode_rewards.append(reward)
    
            all_episode_rewards.append(sum(episode_rewards))
    
        mean_episode_reward = np.mean(all_episode_rewards)
        print("Mean reward:", mean_episode_reward, "Num episodes:", num_episodes)
    
        return mean_episode_reward
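
    # A hedged usage sketch (the original script defines but never calls the
    # helper above): evaluating before training gives a baseline score for the
    # untrained agent.
    mean_reward_before_train = evaluate(model, num_episodes=10)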
    
    # Use a separate environment for evaluation
    eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
    
    # Train the agent for 10,000 timesteps
    model.learn(total_timesteps=10_000)
    
    # Evaluate the trained agent
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100)
    
    print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")
    
    # Log in to the Hugging Face Hub; avoid committing a real token
    # (prefer `huggingface-cli login` or the HF_TOKEN environment variable)
    login(token="****************")
    
    # Save the trained model
    model.save("ECL-TD-RL1-a2c_cartpole.zip")
    
    # Load the trained model
    model = A2C.load("ECL-TD-RL1-a2c_cartpole.zip")
    
    push_to_hub(
        repo_id="Karim-20/a2c_cartpole",
        filename="ECL-TD-RL1-a2c_cartpole.zip",
        commit_message="Add CartPole-v1 agent trained with A2C"
    )
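
    # A minimal follow-up sketch (assumption, not part of the original script):
    # a pushed checkpoint can be pulled back from the Hub with
    # huggingface_sb3.load_from_hub, which downloads the file and returns its
    # local path, ready to be passed to A2C.load.
    from huggingface_sb3 import load_from_hub

    checkpoint = load_from_hub(
        repo_id="Karim-20/a2c_cartpole",
        filename="ECL-TD-RL1-a2c_cartpole.zip",
    )
    model = A2C.load(checkpoint)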