Skip to content
Snippets Groups Projects
Commit 4bc0941e authored by Amoussas Younes's avatar Amoussas Younes
Browse files

Replace a2c_sb3_panda_reach.py

parent 202c8657
Branches
No related tags found
No related merge requests found
...@@ -2,20 +2,22 @@ import wandb ...@@ -2,20 +2,22 @@ import wandb
wandb.init(project="PandaReachJointsDense-v2", sync_tensorboard=True, monitor_gym=True) wandb.init(project="PandaReachJointsDense-v2", sync_tensorboard=True, monitor_gym=True)
import gym import gym
import panda_gym
import torch import torch
from stable_baselines3 import A2C from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
from huggingface_sb3 import push_to_hub
# Set up the CartPole environement # Set up the CartPole environement
# Create the environment # Create the environment
env = gym.make("PandaReachJointsDense-v2") env = gym.make("PandaReachJointsDense-v2")
#env = Monitor(env) env = Monitor(env)
# Reset the environment and get the initial observation # Reset the environment and get the initial observation
observation = env.reset() observation = env.reset()
model = A2C("MlpPolicy", env, verbose=1) model = A2C("MultiInputPolicy", env, verbose=1)
print(model) print(model)
model.learn(total_timesteps=500000) model.learn(total_timesteps=500000)
...@@ -36,3 +38,14 @@ for _ in range(500): ...@@ -36,3 +38,14 @@ for _ in range(500):
env.render() env.render()
wandb.finish() wandb.finish()
# Save the trained model
model.save("a2c_sb3_panda_reach.zip")
# Load the trained model
model = A2C.load("a2c_sb3_panda_reach.zip")
push_to_hub(
repo_id="Younes-hands-on-rl/a2c_sb3_panda_reach",
filename="a2c_sb3_panda_reach.zip",
commit_message="Commit model")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment