Skip to content
Snippets Groups Projects
Commit 7c2027ba authored by cgerest's avatar cgerest
Browse files

Update full workflow

parent 280fe978
No related branches found
No related tags found
No related merge requests found
import gym
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
......@@ -8,7 +8,6 @@ from wandb.integration.sb3 import WandbCallback
from huggingface_hub import login
from huggingface_sb3 import push_to_hub
if __name__ == "__main__":
# Log in HF
login()
......@@ -21,13 +20,22 @@ if __name__ == "__main__":
"env_name": "PandaReachJointsDense-v3",
}
# WB initialization
# Initialize a new wandb run
run = wandb.init(
project="a2c_sb3_panda_reach",
config=config,
sync_tensorboard=True,
monitor_gym=True,
)
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # Record stats such as returns
return env
env = DummyVecEnv([make_env])
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
# WB callback
wandb_callback = WandbCallback(
gradient_save_freq=100,
......@@ -35,20 +43,13 @@ if __name__ == "__main__":
verbose=2,
)
env = gym.make("PandaReachJointsDense-v3")
model = A2C("MultiInputPolicy",
env,
verbose=1,
tensorboard_log=f"runs/{run.id}"
)
model.learn(
total_timesteps=500_000,
total_timesteps=config["total_timesteps"],
callback=wandb_callback
)
model.save("PandaReachJointsDense_1.zip")
# Finish the run
run.finish()
# Upload on HF
......@@ -56,4 +57,4 @@ if __name__ == "__main__":
repo_id="CorentinGst/PandaReachJointsDense_1",
filename="PandaReachJointsDense_1.zip",
commit_message="Add my 1st model trained on PandaReachJointsDense-v3 env",
)
\ No newline at end of file
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment