Skip to content
Snippets Groups Projects
Commit 7c2027ba authored by cgerest
Browse files

Update full workflow

parent 280fe978
Branches
Tags
No related merge requests found
import gym import gymnasium as gym
from stable_baselines3 import A2C from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.vec_env import DummyVecEnv
...@@ -8,7 +8,6 @@ from wandb.integration.sb3 import WandbCallback ...@@ -8,7 +8,6 @@ from wandb.integration.sb3 import WandbCallback
from huggingface_hub import login from huggingface_hub import login
from huggingface_sb3 import push_to_hub from huggingface_sb3 import push_to_hub
if __name__ == "__main__": if __name__ == "__main__":
# Log in HF # Log in HF
login() login()
...@@ -21,13 +20,22 @@ if __name__ == "__main__": ...@@ -21,13 +20,22 @@ if __name__ == "__main__":
"env_name": "PandaReachJointsDense-v3", "env_name": "PandaReachJointsDense-v3",
} }
# WB initialization # Initialize a new wandb run
run = wandb.init( run = wandb.init(
project="a2c_sb3_panda_reach", project="a2c_sb3_panda_reach",
config=config,
sync_tensorboard=True, sync_tensorboard=True,
monitor_gym=True, monitor_gym=True,
) )
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # Record stats such as returns
return env
env = DummyVecEnv([make_env])
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
# WB callback # WB callback
wandb_callback = WandbCallback( wandb_callback = WandbCallback(
gradient_save_freq=100, gradient_save_freq=100,
...@@ -35,20 +43,13 @@ if __name__ == "__main__": ...@@ -35,20 +43,13 @@ if __name__ == "__main__":
verbose=2, verbose=2,
) )
env = gym.make("PandaReachJointsDense-v3")
model = A2C("MultiInputPolicy",
env,
verbose=1,
tensorboard_log=f"runs/{run.id}"
)
model.learn( model.learn(
total_timesteps=500_000, total_timesteps=config["total_timesteps"],
callback=wandb_callback callback=wandb_callback
) )
model.save("PandaReachJointsDense_1.zip") model.save("PandaReachJointsDense_1.zip")
# Finish the run
run.finish() run.finish()
# Upload on HF # Upload on HF
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment