Skip to content
Snippets Groups Projects
Select Git revision
  • de3fdf11abb93881b136f5e38c2ff609d82107fb
  • main default protected
2 results

a2c_sb3_panda_reach.py

Blame
  • Forked from Dellandrea Emmanuel / MSO_3_4-TD1
    Source project has a limited visibility.
    a2c_sb3_panda_reach.py 1.84 KiB
    import gym
    import panda_gym
    from stable_baselines3 import A2C
    from huggingface_sb3 import package_to_hub, push_to_hub
    from gym import envs
    from gymnasium.envs.registration import register
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import wandb
    from wandb.integration.sb3 import WandbCallback
    from stable_baselines3.common.vec_env import VecVideoRecorder
    import dill
    import zipfile
    
    # Environment to train on. Defined first so the W&B run config below records
    # the environment that is actually trained (not a hard-coded placeholder).
    env_id = "PandaReachJointsDense-v2"

    # Initialize Weights & Biases
    total_timesteps = 100000
    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": total_timesteps,
        # Derived from env_id so the dashboard cannot disagree with the training env.
        "env_name": env_id,
    }
    wandb.login()
    run = wandb.init(
        project=f"a2c-{env_id}",  # -> "a2c-PandaReachJointsDense-v2"
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        # NOTE(review): auto-uploads agent videos only if a gym video recorder
        # wraps the env; this script does not add one (VecVideoRecorder is
        # imported but unused) — confirm whether videos are expected.
        monitor_gym=True,
        save_code=True,  # optional
    )
    
    # The Panda environments are registered by `import panda_gym`, so the id is
    # not re-registered here. Pointing this id at another entry point (e.g.
    # CartPoleEnv) would silently train on the wrong task, and registering via
    # gymnasium would not affect the `gym` registry used by gym.make anyway.
    env = gym.make(env_id)

    # Train A2C. TensorBoard logs go to runs/<run id> (synced to W&B through
    # sync_tensorboard) and WandbCallback saves the model under models/<run id>.
    model = A2C(
        config["policy_type"],
        env,
        verbose=1,
        tensorboard_log=f"runs/{run.id}",
    )
    model.learn(
        total_timesteps=total_timesteps,
        callback=WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"models/{run.id}",
        ),
    )
    
    
    # Close the W&B run; nothing below logs to it.
    # Reminder: marking the run as public is done manually in the W&B project
    # settings, not by this code.
    run.finish()

    # Watch the trained policy act deterministically for 1000 steps.
    # SB3 vectorized envs reset finished episodes automatically, so `done`
    # needs no explicit handling in this loop.
    vec_env = model.get_env()
    obs = vec_env.reset()

    for _ in tqdm(range(1000)):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)
        vec_env.render()

    # Release resources held by the environment once the rollout is done.
    vec_env.close()
    
    def save_model(model, env_id):  # use this function to save the model without wandb visualization
        """Serialize *model* with dill into ``model.pkl`` inside ``<env_id>.zip``.

        Use this to save the model without going through W&B.

        Args:
            model: Object to save (here, the trained A2C model); it must be
                serializable by ``dill``.
            env_id: Base name of the archive; ``".zip"`` is appended. The path
                is relative to the current working directory unless ``env_id``
                itself contains a directory.

        Returns:
            The path of the written archive (``f"{env_id}.zip"``).

        Warning:
            Reading the archive back requires unpickling ``model.pkl``, which can
            execute arbitrary code — only load archives you produced yourself.
        """
        # Step 1: Serialize the model
        model_bytes = dill.dumps(model)

        # Step 2: Create a .zip file containing the serialized model.
        # ZipFile defaults to ZIP_STORED (no compression), so request deflate.
        zip_filename = f"{env_id}.zip"
        with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
            zipf.writestr("model.pkl", model_bytes)
        return zip_filename