Compare revisions: moreaum/hands-on-rl, branch main (2 commits on source)
@@ -25,6 +25,4 @@ We finally have an evaluation with 100% success:
Here we set up a complete pipeline to solve the CartPole environment with the A2C algorithm.
-Wandb has been set up to follow the learning phase.
-https://wandb.ai/maximecerise-ecl/cartpole-a2c
![rollout](saves/rollout.png)
+Wandb has been set up to track the learning phase: [WandB tracking](https://wandb.ai/maximecerise-ecl/cartpole-a2c?nw=nwusermaximecerise)
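As an aside, W&B ships an official SB3 callback that can replace hand-rolled logging. A minimal sketch, assuming wandb and stable-baselines3 are installed (the model_save_path is illustrative, not from this repo):

# Hedged sketch: track SB3 training with W&B's built-in callback.
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import A2C

run = wandb.init(project="cartpole-a2c", sync_tensorboard=True)  # mirror TensorBoard scalars to W&B
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_tensorboard/")
model.learn(
    total_timesteps=300_000,
    callback=WandbCallback(model_save_path=f"models/{run.id}", verbose=2),  # illustrative path
)
run.finish()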
import wandb
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from huggingface_sb3 import package_to_hub


def a2c_sb3():
    # Training environment: Monitor records episode statistics, DummyVecEnv
    # wraps it in the vectorized interface that SB3 expects.
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = Monitor(env)
    env = DummyVecEnv([lambda: env])

    wandb.init(
        entity="maximecerise-ecl",
        project="cartpole-a2c",
@@ -21,15 +21,57 @@ wandb.init(
    model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
-    model.learn(total_timesteps=500000)
+    model.learn(total_timesteps=300000)
    model.save("a2c_cartpole")
    env.close()

    # Evaluation environment, wrapped the same way as the training one.
    eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
    eval_env = Monitor(eval_env)
    eval_env = DummyVecEnv([lambda: eval_env])

    success_count = 0
    num_episodes = 100
    scores = []
    for episode in range(num_episodes):
        obs = eval_env.reset()
        done = False
        episode_reward = 0
        while not done:
            action, _ = model.predict(obs)
            obs, reward, done, info = eval_env.step(action)
            done = done[0]  # VecEnv returns arrays; extract the boolean flag
            episode_reward += reward[0]  # likewise extract the scalar reward
        scores.append(episode_reward)
        if episode_reward >= 200:  # an episode counts as a success at 200+
            success_count += 1
        wandb.log({
            "episode": episode,
            "episode_reward": episode_reward,
            "success_rate (%)": success_count / (episode + 1) * 100
        })

    success_rate = success_count / num_episodes * 100
    avg_score = np.mean(scores)
    wandb.log({
        "final_success_rate (%)": success_rate,
        "final_average_score": avg_score
    })
    print(f"A2C success rate over {num_episodes} episodes: {success_rate:.2f}%")
    print(f"Average score: {avg_score:.2f}")

    # Record an evaluation video of the trained policy.
    video_folder = "./videos/"
    eval_env = VecVideoRecorder(eval_env, video_folder,
                                record_video_trigger=lambda x: x == 0, video_length=1000)
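For reference, a self-contained sketch of the VecVideoRecorder pattern used above, assuming the model saved earlier in the script; name_prefix and the 500-step clip length are illustrative:

# Hedged sketch: record one short clip of a trained policy to ./videos/.
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

venv = DummyVecEnv([lambda: gym.make("CartPole-v1", render_mode="rgb_array")])
venv = VecVideoRecorder(
    venv,
    "./videos/",
    record_video_trigger=lambda step: step == 0,  # start recording at the first step
    video_length=500,
    name_prefix="a2c-cartpole",  # illustrative prefix
)
model = A2C.load("a2c_cartpole")  # assumes the model saved above
obs = venv.reset()
for _ in range(500):
    action, _ = model.predict(obs)
    obs, _, _, _ = venv.step(action)
venv.close()  # closing flushes the video file to disk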
@@ -40,6 +82,7 @@ for _ in range(1000):
    eval_env.close()

    # Push the trained model (plus evaluation artifacts) to the Hugging Face Hub.
    package_to_hub(
        model=model,
        model_name="a2c_cartpole",
@@ -47,6 +90,10 @@ package_to_hub(
        env_id="CartPole-v1",
        eval_env=eval_env,
        repo_id="MaximeCerise/a2c_cartpole",
-        commit_message="add a2c"
+        commit_message="add a2c with evaluation"
    )

    wandb.finish()


if __name__ == "__main__":
    a2c_sb3()
\ No newline at end of file
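Side note: SB3's evaluate_policy helper can replace the manual evaluation loop above. A minimal sketch keeping the script's 100-episode count and >= 200 success criterion:

# Hedged sketch: the same 100-episode evaluation via SB3's built-in helper.
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

model = A2C.load("a2c_cartpole")  # model saved by the script above
eval_env = Monitor(gym.make("CartPole-v1"))  # Monitor yields accurate episode rewards
episode_rewards, _ = evaluate_policy(
    model, eval_env, n_eval_episodes=100, return_episode_rewards=True
)
success_rate = np.mean(np.array(episode_rewards) >= 200) * 100  # same success criterion
print(f"Success rate: {success_rate:.2f}% | Average score: {np.mean(episode_rewards):.2f}")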
-import gym
+import gymnasium as gym
import torch
import numpy as np
(binary file changed: no preview for this file type)