training_wandb.py
import wandb
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
# Initialize W&B run
wandb.init(
    project="cartpole-experiments",
    entity="hamza-masmoudi-central-lyon",
    name="a2c-cartpole-run-1",
    config={
        "algorithm": "A2C",
        "environment": "CartPole-v1",
        "learning_rate": 0.0007,
        "n_envs": 1,
        "total_timesteps": 10000,
    },
)
# Create vectorized environment
env = make_vec_env("CartPole-v1", n_envs=1)
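# make_vec_env wraps each env in an SB3 Monitor, which injects an "episode"
# dict (reward "r", length "l") into `infos` when an episode ends; this is
# what the callback below reads.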
# Custom callback to log each completed episode's reward to W&B
class EpisodeRewardCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Monitor adds an "episode" dict to `info` at episode end; its "r"
        # entry is the cumulative episode reward. Only infos[0] is checked,
        # which is sufficient here because n_envs=1.
        if len(self.locals["infos"]) > 0 and "episode" in self.locals["infos"][0]:
            episode_info = self.locals["infos"][0].get("episode", {})
            episode_reward = episode_info.get("r")
            if episode_reward is not None:
                wandb.log({"episode_reward": episode_reward})
                self.episode_rewards.append(episode_reward)
        return True
# Initialize model (A2C's default learning rate, 7e-4, matches the logged config)
model = A2C("MlpPolicy", env, verbose=1)
# Train model with the custom callback
model.learn(
    total_timesteps=10000,
    callback=EpisodeRewardCallback(),
)
# Save the model
model.save("a2c_sb3_cartpole.zip")
wandb.save("a2c_sb3_cartpole.zip")
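# wandb.save uploads the checkpoint to the run's files so it stays paired
# with the logged metrics.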
# Close environment
env.close()
# Finish the W&B run
wandb.finish()
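A minimal follow-up sketch, not part of the original script: reload the saved
checkpoint and estimate its mean episode reward with SB3's evaluate_policy
helper. The path and episode count mirror the script above and are
illustrative.

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Fresh evaluation environment (the training env was closed above)
eval_env = make_vec_env("CartPole-v1", n_envs=1)

# SB3 resolves the ".zip" suffix automatically
loaded_model = A2C.load("a2c_sb3_cartpole")

mean_reward, std_reward = evaluate_policy(loaded_model, eval_env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
eval_env.close()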