"""Train an A2C agent on CartPole-v1, evaluate it, record a video and push it to the Hugging Face Hub."""

import wandb
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from huggingface_sb3 import package_to_hub

ENV_ID = "CartPole-v1"


def _make_vec_env(env_id):
    """Return a one-env DummyVecEnv around a Monitor-wrapped ``gym.make(env_id)`` rendering rgb_array frames."""
    # The env is built inside the factory so the lambda never closes over a
    # name that is rebound afterwards (late-binding closure pitfall).
    return DummyVecEnv([lambda: Monitor(gym.make(env_id, render_mode="rgb_array"))])


def _train(total_timesteps):
    """Train A2C for ``total_timesteps`` steps, save it as ``a2c_cartpole`` and return the model."""
    train_env = _make_vec_env(ENV_ID)
    try:
        model = A2C("MlpPolicy", train_env, verbose=1, tensorboard_log="./a2c_tensorboard/")
        model.learn(total_timesteps=total_timesteps)
        model.save("a2c_cartpole")
    finally:
        train_env.close()
    return model


def _evaluate(model, num_episodes, success_threshold, deterministic):
    """Roll out ``num_episodes`` episodes on a fresh env, logging per-episode metrics to wandb.

    Returns:
        (success_rate, avg_score): the percentage of episodes whose return
        reached ``success_threshold``, and the mean episode return (float).
    """
    eval_env = _make_vec_env(ENV_ID)
    scores = []
    success_count = 0
    try:
        for episode in range(num_episodes):
            obs = eval_env.reset()
            done = False
            episode_reward = 0.0
            while not done:
                action, _ = model.predict(obs, deterministic=deterministic)
                obs, rewards, dones, _ = eval_env.step(action)
                # The VecEnv returns one entry per sub-env (a single one here);
                # extract scalars so scores and wandb values are plain floats.
                episode_reward += float(rewards[0])
                done = bool(dones[0])
            scores.append(episode_reward)
            if episode_reward >= success_threshold:
                success_count += 1
            wandb.log({
                "episode": episode,
                "episode_reward": episode_reward,
                "success_rate (%)": success_count / (episode + 1) * 100,
            })
    finally:
        eval_env.close()
    return success_count / num_episodes * 100, float(np.mean(scores))


def _record_video(model, video_folder, video_length, deterministic):
    """Record the agent's first ``video_length`` steps on a fresh env into ``video_folder``."""
    video_env = VecVideoRecorder(
        _make_vec_env(ENV_ID),
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
    )
    try:
        obs = video_env.reset()
        for _ in range(video_length):
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _, _, _ = video_env.step(action)
    finally:
        video_env.close()


def a2c_sb3(*, total_timesteps=300000, num_episodes=100, success_threshold=200, deterministic=False):
    """Train, evaluate, record and publish an A2C CartPole agent, tracking the run in wandb.

    All parameters are keyword-only and default to the values the script
    originally hard-coded, so ``a2c_sb3()`` behaves as before.

    Args:
        total_timesteps: number of environment steps passed to ``model.learn``.
        num_episodes: number of evaluation episodes; must be >= 1.
        success_threshold: episode return counted as a "success".
            NOTE(review): Gymnasium registers CartPole-v1 with
            reward_threshold=475; confirm whether 200 is the intended bar.
        deterministic: forwarded to ``model.predict`` during evaluation and
            video recording.

    Raises:
        ValueError: if ``num_episodes`` < 1 (the success rate would divide by zero).
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    wandb.init(
        entity="maximecerise-ecl",
        project="cartpole-a2c",
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    )
    try:
        model = _train(total_timesteps)

        success_rate, avg_score = _evaluate(model, num_episodes, success_threshold, deterministic)
        wandb.log({
            "final_success_rate (%)": success_rate,
            "final_average_score": avg_score,
        })
        print(f"Taux de succès du modèle A2C sur {num_episodes} épisodes : {success_rate:.2f}%")
        print(f"Score moyen : {avg_score:.2f}")

        _record_video(model, "./videos/", 1000, deterministic)

        # package_to_hub runs its own evaluation and replay on eval_env, so it
        # must receive a live, unwrapped VecEnv, not a closed or video-wrapped one.
        hub_env = _make_vec_env(ENV_ID)
        try:
            package_to_hub(
                model=model,
                model_name="a2c_cartpole",
                model_architecture="A2C",
                env_id=ENV_ID,
                eval_env=hub_env,
                repo_id="MaximeCerise/a2c_cartpole",
                commit_message="add a2c with evaluation",
            )
        finally:
            hub_env.close()
    finally:
        # Close the wandb run even when training, evaluation or upload fails.
        wandb.finish()


if __name__ == "__main__":
    a2c_sb3()