Skip to content
Snippets Groups Projects
Select Git revision
  • 550d435a2008fc0eef83cb0d30fdb53c32de264f
  • master default protected
2 results

bfs-tree.py

Blame
  • Forked from Vuillemot Romain / INF-TC1
    Source project has a limited visibility.
    training_wandb.py 1.58 KiB
    import wandb
    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.callbacks import BaseCallback
    import gymnasium as gym
    
    # Initialize W&B run. The config is the single source of truth for the
    # experiment's hyperparameters: it is recorded with the run and read back
    # below to build the environment, so what W&B shows is what actually ran.
    wandb.init(
        project="cartpole-experiments",
        entity="hamza-masmoudi-central-lyon",
        name="a2c-cartpole-run-1",
        config={
            "algorithm": "A2C",
            "environment": "CartPole-v1",
            "learning_rate": 0.0007,
            "n_envs": 1,
            "total_timesteps": 10000,
        },
    )

    # Create vectorized environment from the run config instead of repeating
    # the env id and env count as separate literals that could drift apart.
    env = make_vec_env(wandb.config.environment, n_envs=wandb.config.n_envs)
    
    # Custom callback to log episode rewards
    class EpisodeRewardCallback(BaseCallback):
        """Log the return of every finished episode to W&B.

        Finished episodes are detected through the ``"episode"`` entry of each
        sub-environment's ``info`` dict (inserted by SB3's ``Monitor`` wrapper,
        which ``make_vec_env`` applies by default). Each return found is sent to
        ``wandb.log`` under ``"episode_reward"`` and also appended to
        ``self.episode_rewards`` for inspection after training.
        """

        def __init__(self, verbose=0):
            super().__init__(verbose)
            # Returns of all finished episodes, in the order they were seen.
            self.episode_rewards = []

        def _on_step(self) -> bool:
            """Record any episodes that ended on this step; never stop training."""
            # `infos` holds one dict per sub-environment of the VecEnv. Check all
            # of them so episodes are not dropped when n_envs > 1.
            for info in self.locals.get("infos", []):
                episode_reward = info.get("episode", {}).get("r")
                if episode_reward is not None:
                    wandb.log({"episode_reward": episode_reward})
                    self.episode_rewards.append(episode_reward)
            # Returning False would abort model.learn(); this callback only observes.
            return True
    
    # Initialize model. The learning rate comes from the W&B run config so the
    # logged hyperparameter is the one actually used for training.
    model = A2C(
        "MlpPolicy",
        env,
        learning_rate=wandb.config.learning_rate,
        verbose=1,
    )

    try:
        # Train model for the configured budget, logging each finished
        # episode's reward to W&B through the custom callback.
        model.learn(
            total_timesteps=wandb.config.total_timesteps,
            callback=EpisodeRewardCallback(),
        )

        # Save the model locally, then upload the same file to the W&B run.
        model.save("a2c_sb3_cartpole.zip")
        wandb.save("a2c_sb3_cartpole.zip")
    finally:
        # Close environment even if training or saving raises.
        env.close()

    # Finish the W&B run
    wandb.finish()