NOTE(review): this patch was recovered from a copy whose newlines had been
collapsed onto a single line. Line breaks were restored from the diff markers
and every hunk was checked against its "@@ -a,b +c,d @@" line counts.
Leading indentation, the exact content of whitespace-only lines, and the
position of the blank context lines in the fourth hunk could not be recovered
and are inferred -- confirm against the original commit before applying.
Summary of the change: switch from gym to gymnasium, pass `config` to
wandb.init, build the env through Monitor + DummyVecEnv, and read the env
name, policy type and total timesteps from `config` instead of literals.

diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
index ea9d1d39c9687710f6f7dbbc009f2d593dca3b9c..a2cae910b3056dcb7c514e6964cd6eb742b144f5 100644
--- a/a2c_sb3_panda_reach.py
+++ b/a2c_sb3_panda_reach.py
@@ -1,4 +1,4 @@
-import gym
+import gymnasium as gym
 from stable_baselines3 import A2C
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import DummyVecEnv
@@ -8,7 +8,6 @@ from wandb.integration.sb3 import WandbCallback
 from huggingface_hub import login
 from huggingface_sb3 import push_to_hub
 
-
 if __name__ == "__main__":
     # Log in HF
     login()
@@ -21,13 +20,22 @@ if __name__ == "__main__":
         "env_name": "PandaReachJointsDense-v3",
     }
 
-    # WB initialization
+    # Initialize a new wandb run
     run = wandb.init(
         project="a2c_sb3_panda_reach",
+        config=config,
         sync_tensorboard=True,
         monitor_gym=True,
     )
 
+    def make_env():
+        env = gym.make(config["env_name"])
+        env = Monitor(env)  # Record stats such as returns
+        return env
+
+    env = DummyVecEnv([make_env])
+    model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
+
     # WB callback
     wandb_callback = WandbCallback(
         gradient_save_freq=100,
@@ -35,20 +43,13 @@ if __name__ == "__main__":
         verbose=2,
     )
 
-    env = gym.make("PandaReachJointsDense-v3")
-
-    model = A2C("MultiInputPolicy",
-                env,
-                verbose=1,
-                tensorboard_log=f"runs/{run.id}"
-                )
-
     model.learn(
-        total_timesteps=500_000,
+        total_timesteps=config["total_timesteps"],
         callback=wandb_callback
     )
     model.save("PandaReachJointsDense_1.zip")
-    
+
+    # Finish the run
     run.finish()
 
     # Upload on HF
@@ -56,4 +57,4 @@ if __name__ == "__main__":
         repo_id="CorentinGst/PandaReachJointsDense_1",
         filename="PandaReachJointsDense_1.zip",
         commit_message="Add my 1st model trained on PandaReachJointsDense-v3 env",
-    )
\ No newline at end of file
+    )