# a2c_sb3_panda_reach.py
# Forked from Dellandrea Emmanuel / MSO_3_4-TD1
# (8 commits behind, 1 commit ahead of the upstream repository)
# Authored by Paganelli Emilien
import wandb
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
import panda_gym
# Run configuration for Weights & Biases: policy class, training budget,
# and the Gymnasium environment id used throughout the script.
config = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": 500000,
    "env_name": "PandaReach-v3",
}

# Open the W&B run: TensorBoard metrics are mirrored and the source code saved.
run = wandb.init(
    project="panda-gym-training",
    config=config,
    sync_tensorboard=True,
    save_code=True,
)
def make_env():
    """Build one PandaReach environment for the vectorized wrapper.

    Returns:
        gymnasium.Env: a freshly created environment instance.

    ``render_mode="rgb_array"`` is required so that the VecVideoRecorder
    wrapping the vectorized env can capture frames; without it gymnasium
    renders nothing and video recording fails.
    """
    env = gym.make(config["env_name"], render_mode="rgb_array")
    return env
# Wrap the single-env factory in a vectorized interface, then add periodic
# video capture: a 200-step clip at every multiple of 50 000 steps, written
# under videos/<run-id>.
vec_env = DummyVecEnv([make_env])
env = VecVideoRecorder(
    vec_env,
    f"videos/{run.id}",
    record_video_trigger=lambda step: step % 50000 == 0,
    video_length=200,
)
# Train the A2C agent. `tensorboard_log` must be set for SB3 to emit
# TensorBoard event files at all; `wandb.init(sync_tensorboard=True)` above
# only mirrors that directory into the W&B run, so without this argument
# no training metrics would ever be synced.
model = A2C(
    config["policy_type"],
    env,
    verbose=1,
    tensorboard_log=f"runs/{run.id}",
)
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(),
)
# Persist the trained policy; SB3 appends ".zip" to the given path.
model.save("a2c_panda_reach_model")
# Upload the saved model as a W&B artifact on the run that is already
# active. The original code re-called wandb.init() while the training run
# was still open, which shadowed `run`, spawned a second run, and left the
# first one unfinished — logging to the existing run avoids all of that.
nom_artefact = "a2c_panda_reach_model"
artefact = wandb.Artifact(name=nom_artefact, type="model")
artefact.add_file("a2c_panda_reach_model.zip")  # model.save() wrote this .zip
run.log_artifact(artefact)
run.finish()  # mark the run complete so the artifact and metrics flush