From 9d7b1ba981fb0901447536eb839d898ffd7e6542 Mon Sep 17 00:00:00 2001
From: Ghelfi Manon <manon.ghelfi@ecl19.ec-lyon.fr>
Date: Wed, 8 Feb 2023 15:00:31 +0000
Subject: [PATCH] Upload New File

---
 a2c_sb3_panda_reach.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 a2c_sb3_panda_reach.py

diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000..b8d6f9b
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,52 @@
+"""Train an A2C agent on PandaReachJointsDense-v2 and publish the result.
+
+The script trains with Stable-Baselines3, logs the summed reward of one
+evaluation episode to Weights & Biases, saves the model locally and uploads
+the saved archive to the Hugging Face Hub.
+"""
+
+import gym
+import numpy as np
+import panda_gym  # noqa: F401 - not used by name; importing it registers the Panda envs
+import wandb
+from huggingface_sb3 import push_to_hub
+from stable_baselines3 import A2C
+
+ENV_ID = "PandaReachJointsDense-v2"
+RUN_NAME = "panda-reach-joints-dense-v2"
+
+wandb.init(project=RUN_NAME)
+
+# Create the environment once: a second gym.make() call would only leave an
+# unused, unclosed environment behind.
+# NOTE(review): render_mode="human" is kept from the original script; it
+# presumably slows a 500k-step training run - confirm it is wanted.
+env = gym.make(ENV_ID, render_mode="human")
+
+# The Panda tasks expose dictionary observations, and Stable-Baselines3 only
+# accepts those with "MultiInputPolicy" ("MlpPolicy" is rejected).
+model = A2C("MultiInputPolicy", env, verbose=1)
+model.learn(total_timesteps=500000)
+
+# Roll out one episode with the trained policy and collect its rewards.
+rewards = []
+obs = env.reset()
+done = False
+while not done:
+    action, _states = model.predict(obs)
+    obs, reward, done, info = env.step(action)
+    rewards.append(reward)
+env.close()
+
+wandb.log({"rewards": np.sum(rewards)})
+
+# save() appends ".zip", producing the archive that is uploaded below.
+model.save(RUN_NAME)
+
+push_to_hub(
+    repo_id="manonghelfi/panda-reach-joints-dense-v2",
+    filename="./panda-reach-joints-dense-v2.zip",
+    commit_message="Added panda-reach-joints-dense-v2 model trained with A2C",
+)
+
+wandb.finish()
-- 
GitLab