# a2c_sb3_panda_reach.py
# Forked from Dellandrea Emmanuel / MSO_3_4-TD1
# (source project has limited visibility; pasted GitLab UI text converted to comments)
import gym
import panda_gym
from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub, push_to_hub
from gym import envs
from gymnasium.envs.registration import register
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.vec_env import VecVideoRecorder
import dill
import zipfile
# Initialize Weights & Biases
total_timesteps = 100000
# Experiment metadata logged to W&B. env_name previously said "CartPole-v1",
# which contradicted the environment actually trained below
# ("PandaReachJointsDense-v2") and would mislabel the run.
config = {
"policy_type": "MlpPolicy",
"total_timesteps": total_timesteps,
"env_name": "PandaReachJointsDense-v2",
}
# Authenticate with Weights & Biases (interactive prompt / env var API key).
wandb.login()
# Start a tracked run; `run.id` is used below for tensorboard and model paths.
run = wandb.init(
project="a2c-PandaReachJointsDense-v2",
config=config,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=True, # auto-upload the videos of agents playing the game
save_code=True, # optional
)
env_id = "PandaReachJointsDense-v2"
# panda_gym registers its environments as a side effect of `import panda_gym`,
# so no manual register() call is needed here. The previous register() call was
# a bug on two counts: it pointed this Panda env id at
# 'gym.envs.classic_control:CartPoleEnv' (wrong environment entirely), and it
# wrote into gymnasium's registry while gym.make() below reads gym's registry.
env = gym.make(env_id)
# A2C with an MLP policy; tensorboard logs go under the W&B run id so they
# sync with the dashboard (sync_tensorboard=True above).
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
# Train for `total_timesteps` steps; WandbCallback uploads gradients every
# 100 updates and saves model checkpoints under models/<run id>.
model.learn(total_timesteps=total_timesteps, callback=WandbCallback(
gradient_save_freq=100,
model_save_path=f"models/{run.id}"))
# Mark the run as public in W&B project settings
# Close the W&B run so buffered metrics/artifacts are flushed.
run.finish()
# Roll the trained policy out for 1000 steps, rendering each frame.
vec_env = model.get_env()
observation = vec_env.reset()
n_eval_steps = 1000
for _ in tqdm(range(n_eval_steps)):
    # deterministic=True: take the policy's best action, no exploration noise
    act, _hidden = model.predict(observation, deterministic=True)
    observation, _reward, _done, _info = vec_env.step(act)
    vec_env.render()
def save_model(model, env_id, serializer=None): # use this function to save the model without wandb visualization
    """Serialize *model* and write it into ``<env_id>.zip``.

    Parameters
    ----------
    model : object
        Any object the serializer can pickle (e.g. an SB3 model).
    env_id : str
        Basename (or path prefix) for the output archive; the file is
        written to ``env_id + ".zip"``.
    serializer : module, optional
        A pickle-compatible module exposing ``dumps``. Defaults to ``dill``,
        which handles object graphs (closures, lambdas) stdlib ``pickle``
        cannot. NOTE: only load such archives from trusted sources —
        unpickling executes arbitrary code.

    Returns
    -------
    str
        Path of the zip archive that was written.
    """
    if serializer is None:
        serializer = dill
    # Step 1: Serialize the model to bytes in memory
    model_bytes = serializer.dumps(model)
    # Step 2: Create a .zip file containing the serialized model
    zip_filename = env_id + ".zip"
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        zipf.writestr("model.pkl", model_bytes)
    return zip_filename