"""Train A2C on PandaReachJointsDense-v2, log one rollout, and publish the model.

Script file: ``a2c_sb3_panda_reach.py``.

Steps, in order:
1. Start a Weights & Biases run.
2. Build the environment and train an A2C agent on it.
3. Roll out a single episode with the trained agent and log its summed reward.
4. Save the model to disk and upload the resulting archive to the Hugging Face Hub.
"""

import gym
import numpy as np
import panda_gym  # noqa: F401 -- never referenced by name; presumably imported so the Panda envs get registered with gym
import wandb
from huggingface_sb3 import push_to_hub
from stable_baselines3 import A2C

ENV_ID = "PandaReachJointsDense-v2"
RUN_NAME = "panda-reach-joints-dense-v2"  # used as the wandb project and as the saved-model basename
TOTAL_TIMESTEPS = 500000
HUB_REPO_ID = "manonghelfi/panda-reach-joints-dense-v2"

wandb.init(project=RUN_NAME)

# Create the environment exactly once. The script previously built a second,
# unrendered instance first and immediately overwrote it without closing it.
# NOTE(review): confirm that the installed gym / panda_gym versions accept
# render_mode="human" for this env id, and that rendering during the whole
# training run is actually wanted -- TODO confirm.
env = gym.make(ENV_ID, render_mode="human")

# Stable-Baselines3 requires "MultiInputPolicy" for Dict observation spaces and
# refuses "MlpPolicy" for them; fall back to "MlpPolicy" for any other space so
# non-Dict environments behave exactly as before.
if isinstance(env.observation_space, gym.spaces.Dict):
    policy = "MultiInputPolicy"
else:
    policy = "MlpPolicy"

model = A2C(policy, env, verbose=1)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

# Roll out one episode with the trained agent, collecting the per-step rewards.
rewards = []
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
    rewards.append(reward)

# The rollout is finished; release the environment's resources.
env.close()

wandb.log({'rewards': np.sum(rewards)})

# `push_to_hub` below is given the same basename with a ".zip" suffix, i.e. the
# archive that `model.save` is expected to have written.
model.save(RUN_NAME)

push_to_hub(
    repo_id=HUB_REPO_ID,
    filename=f"./{RUN_NAME}.zip",
    commit_message="Added panda-reach-joints-dense-v2 model trained with A2C",
)

# Close the wandb run explicitly now that everything has been logged.
wandb.finish()