diff --git a/README.md b/README.md
index 1d483f66cd54c1d586b0147a4168c7ea459c3d2b..0625bf06137df83594a803662564fa9329ef6adf 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,24 @@
-# Reinforcement learning
+# TD 1 : Hands-On Reinforcement Learning
+This TD introduces different algorithms, frameworks and tools used in Reinforcement Learning. The methods are applied to robotics tasks: the CartPole and PandaReachJointsDense environments.
+
+## REINFORCE
+The REINFORCE algorithm is used to solve the CartPole environment. The plot showing the total reward across episodes can be seen below: 
+The Python script used is: reinforce_cartpole.py.
+
 ## Familiarization with a complete RL pipeline: Application to training a robotic arm
-### Get familiar with Hugging Face Hub
+### Stable-Baselines3 and HuggingFace
+In this section, the Stable-Baselines3 package is used to solve the CartPole environment with the Advantage Actor-Critic (A2C) algorithm.
+The Python script used is: a2c_sb3_cartpole.py.
+
+The trained model is shared on HuggingFace and is available at the following link: https://huggingface.co/oscarchaufour/a2c-CartPole-v1
+
+### Weights & Biases
+The Weights & Biases package is used to visualize the training and the performance of a model. The link to the run visualization on WandB is: https://wandb.ai/oscar-chaufour/a2c-cartpole-v1?workspace=user-oscar-chaufour
+
+### Full workflow with panda-gym
+The full training-visualization-sharing workflow is applied to the PandaReachJointsDense environment.
-
-Link to the model on the hub :
-# REINFORCE
-Plot showing the total reward accross episodes: 
-# A2C trained model
-Link to the trained model (available on huggingFace): https://huggingface.co/oscarchaufour/a2c-CartPole-v1
diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3aa767acd0c048cb8155eea8ad5119b8f312755
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,64 @@
+import gym
+import panda_gym  # importing panda_gym registers the Panda environments with gym
+from stable_baselines3 import A2C
+from huggingface_sb3 import package_to_hub, push_to_hub
+from tqdm import tqdm
+import wandb
+from wandb.integration.sb3 import WandbCallback
+import dill
+import zipfile
+
+# Initialize Weights & Biases
+env_id = "PandaReachJointsDense-v2"
+total_timesteps = 100000
+config = {
+    "policy_type": "MlpPolicy",
+    "total_timesteps": total_timesteps,
+    "env_name": env_id,
+}
+wandb.login()
+run = wandb.init(
+    project="a2c-PandaReachJointsDense-v2",
+    config=config,
+    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+    monitor_gym=True,  # auto-upload the videos of agents playing the game
+    save_code=True,  # optional
+)
+
+env = gym.make(env_id)
+
+model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
+model.learn(total_timesteps=total_timesteps, callback=WandbCallback(
+    gradient_save_freq=100,
+    model_save_path=f"models/{run.id}"))
+
+# Mark the run as public in W&B project settings
+run.finish()
+
+# Visualize the trained agent
+vec_env = model.get_env()
+obs = vec_env.reset()
+
+for i in tqdm(range(1000)):
+    action, _state = model.predict(obs, deterministic=True)
+    obs, reward, done, info = vec_env.step(action)
+    vec_env.render()
+
+
+def save_model(model, env_id):  # use this function to save the model without wandb visualization
+    # Step 1: Serialize the model
+    model_bytes = dill.dumps(model)
+
+    # Step 2: Create a .zip file containing the serialized model
+    zip_filename = env_id + ".zip"
+    with zipfile.ZipFile(zip_filename, 'w') as zipf:
+        zipf.writestr("model.pkl", model_bytes)
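+
+
+# ---------------------------------------------------------------------------
+# Illustrative sketch only (not part of the original workflow): a possible
+# counterpart to save_model that loads a model serialized with dill back from
+# the .zip archive created above. The helper name load_model is an assumption
+# made for illustration, not an existing Stable-Baselines3 or wandb API.
+# ---------------------------------------------------------------------------
+def load_model(env_id):
+    # Read the serialized model bytes back out of the archive ...
+    zip_filename = env_id + ".zip"
+    with zipfile.ZipFile(zip_filename, 'r') as zipf:
+        model_bytes = zipf.read("model.pkl")
+    # ... and deserialize them with dill
+    return dill.loads(model_bytes)
+
+
+# Example usage (assumed): persist the trained agent, then reload it later.
+# save_model(model, env_id)
+# reloaded_model = load_model(env_id)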