"""Train an A2C agent on a panda-gym reach task and publish it.

Running this file as a script:

1. logs in to the Hugging Face Hub,
2. starts a Weights & Biases run,
3. trains an A2C model on ``CONFIG["env_name"]``,
4. saves the model to ``MODEL_FILENAME``,
5. uploads the saved file to ``HUB_REPO_ID``.
"""

import gymnasium as gym
import panda_gym  # noqa: F401  # not referenced by name; kept for its import-time env registration
import wandb
from huggingface_hub import login
from huggingface_sb3 import push_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from wandb.integration.sb3 import WandbCallback

# Hyper-parameters; also forwarded to wandb as the run's config.
CONFIG = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": 500_000,
    "env_name": "PandaReachJointsDense-v3",
}

# Local file the trained model is written to, and later uploaded from.
MODEL_FILENAME = "PandaReachJointsDense_1.zip"

# Hugging Face Hub repository that receives the trained model.
HUB_REPO_ID = "CorentinGst/PandaReachJointsDense_1"


def make_env():
    """Build one training environment wrapped in a ``Monitor``.

    Returns:
        The ``CONFIG["env_name"]`` environment, wrapped so that episode
        statistics such as returns are recorded.
    """
    env = gym.make(CONFIG["env_name"])
    return Monitor(env)  # Record stats such as returns


def main():
    """Run the full pipeline: login, train, save, and upload.

    Any exception raised during training propagates to the caller; in that
    case the wandb run and the environment are still closed, and nothing is
    uploaded to the Hub.
    """
    # Authenticate first so a credentials problem surfaces before the
    # (long) training run rather than after it.
    login()

    # Using the run as a context manager guarantees it is finished even when
    # training raises, instead of only on the success path.
    with wandb.init(
        project="a2c_sb3_panda_reach",
        config=CONFIG,
        sync_tensorboard=True,
        monitor_gym=True,
    ) as run:
        env = DummyVecEnv([make_env])
        try:
            model = A2C(
                CONFIG["policy_type"],
                env,
                verbose=1,
                tensorboard_log=f"runs/{run.id}",
            )

            wandb_callback = WandbCallback(
                gradient_save_freq=100,
                model_save_path=f"models/{run.id}",
                verbose=2,
            )

            model.learn(
                total_timesteps=CONFIG["total_timesteps"],
                callback=wandb_callback,
            )
            model.save(MODEL_FILENAME)
        finally:
            # Release the environment whether or not training succeeded.
            env.close()

    # Only reached when training and saving succeeded.
    push_to_hub(
        repo_id=HUB_REPO_ID,
        filename=MODEL_FILENAME,
        commit_message="Add my 1st model trained on PandaReachJointsDense-v3 env",
    )


if __name__ == "__main__":
    main()