From 789865bb9d881482565a3da03d98706b5a58bc10 Mon Sep 17 00:00:00 2001
From: Majdi Karim <karim.majdi@etu.ec-lyon.fr>
Date: Tue, 5 Mar 2024 21:39:25 +0000
Subject: [PATCH] Add new file

---
 a2c_sb3_panda_reach.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 a2c_sb3_panda_reach.py

diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000..380bb58
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,70 @@
+### LIBRARIES
+
+import gymnasium as gym
+from stable_baselines3 import A2C
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import push_to_hub
+import panda_gym
+import os
+from huggingface_hub import login
+
+
+
+#dir_path = os.path.dirname(os.path.realpath(__file__))
+#os.chdir(dir_path)
+
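+### CONFIGURATION AND W&B RUN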
+config = {
+    "policy_type": "MultiInputPolicy",
+    "total_timesteps": 250000,
+    "env_name": "PandaReachJointsDense-v3",
+}
+
+run = wandb.init(
+    project="sb3-panda-reach",
+    config=config,
+    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+    monitor_gym=True,  # auto-upload the videos of agents playing the game
+    save_code=True,  # optional
+)
+
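+### ENVIRONMENT AND TRAINING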
+def make_env():
+    env = gym.make(config["env_name"])
+    env = Monitor(env)  # record stats such as returns
+    return env
+
+env = DummyVecEnv([make_env])
+# env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % 2000 == 0, video_length=200)
+model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
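+# WandbCallback logs gradients and training metrics to W&B and saves the model under models/<run id>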
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=100,
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    ),
+)
+
+run.finish()
+
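+# Log in to the Hugging Face Hub before pushing the model (token redacted)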
+login(token="*********")
+
+
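+### SAVE AND PUSH TO THE HUGGING FACE HUB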
+# Save the trained model
+model.save("ECL-TD-RL1-a2c_panda_reach.zip")
+
+# Load the trained model
+model = A2C.load("ECL-TD-RL1-a2c_panda_reach.zip")
+
+push_to_hub(
+    repo_id="Karim-20/a2c_cartpole",
+    filename="ECL-TD-RL1-a2c_panda_reach.zip",
+    commit_message="Add PandaReachJointsDense-v3 environment, trained with an A2C agent"
+)
-- 
GitLab