Skip to content
Snippets Groups Projects
Select Git revision
  • main default protected
  • epaganel-main-patch-63e8
2 results

a2c_sb3_panda_reach.py

Blame
  • Forked from Dellandrea Emmanuel / MSO_3_4-TD1
    8 commits behind, 1 commit ahead of the upstream repository.
    a2c_sb3_panda_reach.py 1.26 KiB
    import wandb
    import gymnasium as gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
    from wandb.integration.sb3 import WandbCallback 
    import panda_gym
    
    # Run settings; the same mapping is passed to Weights & Biases as the run config.
    config = dict(
        policy_type="MultiInputPolicy",  # policy type
        total_timesteps=500000,  # total number of training timesteps
        env_name="PandaReach-v3",  # environment name
    )

    # Start the Weights & Biases run that tracks this training session.
    run = wandb.init(
        project="panda-gym-training",
        config=config,
        save_code=True,
        sync_tensorboard=True,
    )
    
    
    
    def make_env():
        """Create one monitored environment instance.

        Returns:
            The environment named by ``config["env_name"]``, created with
            ``render_mode="rgb_array"`` and wrapped in Stable-Baselines3's
            ``Monitor``.
        """
        # Imported locally so the file's top-level import list stays untouched.
        from stable_baselines3.common.monitor import Monitor

        # Request rgb_array rendering explicitly: the video recorder consumes
        # rendered frames, so do not rely on the environment's default mode.
        env = gym.make(config["env_name"], render_mode="rgb_array")
        # Monitor tracks per-episode reward and length. SB3 only adds it
        # automatically for non-vectorized envs, so wrap here.
        return Monitor(env)
    
    # Vectorize a single environment built by make_env, then wrap it in a video
    # recorder that writes into a folder named after the W&B run id.
    env = VecVideoRecorder(
        venv=DummyVecEnv([make_env]),
        video_folder=f"videos/{run.id}",
        # Start a recording whenever the trigger argument is a multiple of 50000.
        record_video_trigger=lambda step: step % 50000 == 0,
        video_length=200,
    )
    
    
    # Build the A2C agent. tensorboard_log makes SB3 write its training metrics
    # as TensorBoard event files; without it there is nothing for the W&B run
    # (opened with sync_tensorboard=True) to pick up.
    model = A2C(
        config["policy_type"],
        env,
        verbose=1,
        tensorboard_log=f"runs/{run.id}",
    )

    # Train for the configured number of timesteps with the W&B callback attached.
    model.learn(total_timesteps=config["total_timesteps"], callback=WandbCallback())

    # Persist the trained agent; the upload step later reads
    # "a2c_panda_reach_model.zip".
    model.save("a2c_panda_reach_model")
    
    
    # Upload the saved model as an artifact of the run opened at the top of the
    # script. Re-calling wandb.init while that run is still active would only
    # rebind `run`, so reuse it; leaving the `with` block finishes the run.
    nom_artefact = "a2c_panda_reach_model"
    with run:
        artefact = wandb.Artifact(name=nom_artefact, type="model")
        artefact.add_file("a2c_panda_reach_model.zip")  # add the trained model file
        run.log_artifact(artefact)