From 789865bb9d881482565a3da03d98706b5a58bc10 Mon Sep 17 00:00:00 2001
From: Majdi Karim <karim.majdi@etu.ec-lyon.fr>
Date: Tue, 5 Mar 2024 21:39:25 +0000
Subject: [PATCH] Add new file

---
 a2c_sb3_panda_reach.py | 65 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 a2c_sb3_panda_reach.py

diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000..380bb58
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,65 @@
+### LIBRARIES
+
+import gymnasium as gym
+from stable_baselines3 import A2C
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import push_to_hub
+import panda_gym  # registers the Panda robot environments with gymnasium
+import os
+from huggingface_hub import login
+
+
+
+#dir_path = os.path.dirname(os.path.realpath(__file__))
+#os.chdir(dir_path)
+
+config = {
+    "policy_type": "MultiInputPolicy",  # needed for the env's dict observation space
+    "total_timesteps": 250000,
+    "env_name": "PandaReachJointsDense-v3",
+}
+
+run = wandb.init(
+    project="sb3-panda-reach",
+    config=config,
+    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+    monitor_gym=True,  # auto-upload the videos of agents playing the game
+    save_code=True,  # optional
+)
+
+def make_env():
+    env = gym.make(config["env_name"])
+    env = Monitor(env)  # record stats such as returns
+    return env
+
+env = DummyVecEnv([make_env])  # single-process vectorized env expected by SB3
+# env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % 2000 == 0, video_length=200)
+model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=100,
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    ),
+)
+
+run.finish()
+
+login(token="*********")  # authenticate with the Hugging Face Hub before pushing
+
+
+# Save the trained model
+model.save("ECL-TD-RL1-a2c_panda_reach.zip")
+
+# Load the trained model
+model = A2C.load("ECL-TD-RL1-a2c_panda_reach.zip")
+
+push_to_hub(
+    repo_id="Karim-20/a2c_cartpole",
+    filename="ECL-TD-RL1-a2c_panda_reach.zip",
+    commit_message="Add PandaReachJointsDense-v3 environment, agent trained with A2C"
+)
--
GitLab
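
Once the patch is applied and the script has run, the uploaded checkpoint can be pulled back from the Hugging Face Hub and sanity-checked. The sketch below is an illustration rather than part of the patch: it assumes the same repo_id and filename passed to push_to_hub above, and uses load_from_hub from huggingface_sb3 together with evaluate_policy from stable_baselines3.

import gymnasium as gym
import panda_gym  # registers PandaReachJointsDense-v3 with gymnasium
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

# Download the checkpoint pushed by the training script (same repo_id/filename as above).
checkpoint = load_from_hub(
    repo_id="Karim-20/a2c_cartpole",
    filename="ECL-TD-RL1-a2c_panda_reach.zip",
)
model = A2C.load(checkpoint)

# Roll out a few deterministic evaluation episodes on the training environment.
eval_env = gym.make("PandaReachJointsDense-v3")
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")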