Skip to content
Snippets Groups Projects
Commit bb31e1b0 authored by Brussart Paul-emile's avatar Brussart Paul-emile
Browse files

Adding a2c_sb3_panda_reach.py

parent 25c3cb07
Branches
No related tags found
No related merge requests found
import gym
import panda_gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
# Experiment settings. The same mapping is passed to W&B as the run config
# and is read below when building the environment, the model and the
# training call.
config = dict(
    policy_type="MultiInputPolicy",  # policy name handed to A2C
    total_timesteps=500000,  # training budget handed to model.learn
    env_name="PandaReachJointsDense-v2",  # environment id handed to gym.make
)

# Start the W&B run for this experiment under the "pandareach" project.
run = wandb.init(
    project="pandareach",
    config=config,
    sync_tensorboard=True,  # per W&B docs: mirror TensorBoard logs to the run
    monitor_gym=True,  # per W&B docs: upload gym-recorded videos, if any
    save_code=True,  # per W&B docs: store a copy of this script with the run
)
# Define a function to create the environment
def make_env():
    """Build one monitored instance of the configured environment.

    The environment id is read from the module-level ``config`` mapping
    (key ``"env_name"``).

    Returns:
        The environment created by ``gym.make``, wrapped in ``Monitor``.
    """
    env = gym.make(config["env_name"])  # Create the environment using the specified name
    env = Monitor(env)  # Wrap the environment in a Monitor to record various metrics
    return env
# Wrap the environment factory in a vectorized env holding a single environment.
env = DummyVecEnv([make_env])

# Build the A2C agent; TensorBoard logs go to a directory named after the W&B run.
tensorboard_dir = f"runs/{run.id}"
model = A2C(
    config["policy_type"],
    env,
    verbose=1,
    tensorboard_log=tensorboard_dir,
)

# Callback that connects training to the W&B run.
wandb_callback = WandbCallback(
    gradient_save_freq=10000,  # gradient logging interval
    model_save_path=f"models/{run.id}",  # where the callback stores the model
    verbose=2,  # verbosity level of the callback itself
)

# Train for the configured number of timesteps.
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=wandb_callback,
)

# Keep a local copy of the trained model as well.
model.save("a2c_sb3_panda_reach")

# Finish the W&B run
run.finish()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment