# Reinforcement learning
# TD 1 : Hands-On Reinforcement Learning
This TD introduces different algorithms, frameworks and tools used in Reinforcement Learning. The methods are applied to robotics tasks: the CartPole and PandaReachJointsDense environments.
## REINFORCE
The REINFORCE algorithm is used to solve the CartPole environment. The plot showing the total reward across episodes can be seen below: ![Alt text](images/reinforce_rewards.png)
The Python script used is: reinforce_cartpole.py.
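The script itself is not reproduced here; the following is a minimal sketch of the REINFORCE update it implements (the network size, learning rate, discount factor and the pre-0.26 Gym API are assumptions, not the exact contents of reinforce_cartpole.py):

```python
import gym
import torch
import torch.nn as nn

# Simple MLP policy: observation (4 values) -> action probabilities (2 actions)
policy = nn.Sequential(nn.Linear(4, 128), nn.ReLU(), nn.Linear(128, 2), nn.Softmax(dim=-1))
optimizer = torch.optim.Adam(policy.parameters(), lr=5e-3)
env = gym.make("CartPole-v1")

for episode in range(500):
    obs = env.reset()
    log_probs, rewards = [], []
    done = False
    while not done:
        probs = policy(torch.as_tensor(obs, dtype=torch.float32))
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))
        obs, reward, done, info = env.step(action.item())
        rewards.append(reward)

    # Discounted returns, computed backwards over the episode
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + 0.99 * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # normalize for stability

    # Policy-gradient loss: maximize the return-weighted log-probabilities
    loss = -(torch.stack(log_probs) * returns).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```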
## Familiarization with a complete RL pipeline: Application to training a robotic arm
### Get familiar with Hugging Face Hub
### Stable-Baselines3 and HuggingFace
In this section, the Stable-Baselines3 package is used to solve the CartPole environment with the Advantage Actor-Critic (A2C) algorithm.
The Python script used is: a2c_sb3_cartpole.py.
The trained model is shared on Hugging Face and is available at the following link: https://huggingface.co/oscarchaufour/a2c-CartPole-v1
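A minimal sketch of what a2c_sb3_cartpole.py and the Hub upload look like (the hyperparameters and the use of huggingface_sb3.push_to_hub are assumptions; only the repository id comes from the link above):

```python
import gym
from stable_baselines3 import A2C
from huggingface_sb3 import push_to_hub

# Train A2C on CartPole with Stable-Baselines3
env = gym.make("CartPole-v1")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)

# Save the trained model locally, then upload the zip to the Hugging Face Hub
model.save("a2c-CartPole-v1")  # writes a2c-CartPole-v1.zip
push_to_hub(
    repo_id="oscarchaufour/a2c-CartPole-v1",
    filename="a2c-CartPole-v1.zip",
    commit_message="Upload A2C CartPole-v1 agent",
)
```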
### Weights & Biases
The Weights & Biases package is used to visualize the taining and the performances of a model. The link to the run visualization on WandB is: https://wandb.ai/oscar-chaufour/a2c-cartpole-v1?workspace=user-oscar-chaufour
### Full workflow with panda-gym
The full training-visualization-sharing workflow is applied to the PandaReachJointsDense environment.
Link to the model on the Hub:
The script used for this workflow is reproduced below:

```python
import gym
import panda_gym  # registers the PandaReach* environments on import
from stable_baselines3 import A2C
from tqdm import tqdm
import wandb
from wandb.integration.sb3 import WandbCallback
import dill
import zipfile

# Initialize Weights & Biases
total_timesteps = 100000
env_id = "PandaReachJointsDense-v2"
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": total_timesteps,
    "env_name": env_id,
}
wandb.login()
run = wandb.init(
    project="a2c-PandaReachJointsDense-v2",
    config=config,
    sync_tensorboard=True,  # auto-upload SB3's TensorBoard metrics
    monitor_gym=True,       # auto-upload videos of the agent playing
    save_code=True,         # optional
)

# panda_gym already registers PandaReachJointsDense-v2 when imported,
# so no manual register() call is needed here.
env = gym.make(env_id)

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=total_timesteps,
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
    ),
)
# Mark the run as public in the W&B project settings
run.finish()

# Roll out the trained policy for a quick visual check
vec_env = model.get_env()
obs = vec_env.reset()
for _ in tqdm(range(1000)):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()


def save_model(model, env_id):
    """Save the model locally, without the W&B/Hub tooling."""
    # Step 1: serialize the model
    model_bytes = dill.dumps(model)
    # Step 2: create a .zip file containing the serialized model
    zip_filename = env_id + ".zip"
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        zipf.writestr("model.pkl", model_bytes)
```
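As a quick illustration (not part of the original script), the archive written by `save_model` can be read back with `dill`; whether the full SB3 model round-trips cleanly depends on what the model object references, so the snippet below is a hypothetical sketch:

```python
# Hypothetical usage of save_model: write the archive, then restore the model
save_model(model, env_id)

with zipfile.ZipFile(env_id + ".zip") as zipf:
    restored_model = dill.loads(zipf.read("model.pkl"))

# The restored object is used like the original SB3 model
action, _ = restored_model.predict(obs, deterministic=True)
```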