diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000000000000000000000000000000000000..01f7163c06d026df0f873669f471d96cff810d3a
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,50 @@
+import gym
+import panda_gym
+from stable_baselines3 import A2C
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv
+import wandb
+from wandb.integration.sb3 import WandbCallback
+
+# Configuration for the experiment
+config = {
+    "policy_type": "MultiInputPolicy",  # Policy type to use (dict observation space)
+    "total_timesteps": 500000,  # Total number of timesteps for training
+    "env_name": "PandaReachJointsDense-v2",  # Name of the environment to train on
+}
+
+# Initialize the W&B run with the specified project and configuration
+run = wandb.init(
+    project="pandareach",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+
+# Factory function that builds a single monitored environment
+def make_env():
+    env = gym.make(config["env_name"])  # Create the environment by name
+    env = Monitor(env)  # Wrap in a Monitor to record episode rewards and lengths
+    return env
+
+# Create a vectorized environment from the factory function
+env = DummyVecEnv([make_env])
+
+# Initialize the A2C model with the specified policy type and environment
+model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
+
+# Train the model for the specified number of timesteps
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=10000,  # Log gradients every 10000 timesteps
+        model_save_path=f"models/{run.id}",  # Path where model checkpoints are saved
+        verbose=2,  # Verbosity level for the WandbCallback
+    )
+)
+
+model.save("a2c_sb3_panda_reach")
+
+# Finish the W&B run
+run.finish()