%% Cell type:code id: tags:
```
!pip install torch
```
%% Output
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.0+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)
Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.4)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)
%% Cell type:code id: tags:
```
!pip install gym==0.26.2
!pip install pyglet==2.0.10
!pip install pygame==2.5.2
!pip install PyQt5
```
%% Output
Collecting gym==0.26.2
Downloading gym-0.26.2.tar.gz (721 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 721.7/721.7 kB 4.1 MB/s eta 0:00:00
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from gym==0.26.2) (1.23.5)
Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gym==0.26.2) (2.2.1)
Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym==0.26.2) (0.0.8)
Building wheels for collected packages: gym
Building wheel for gym (pyproject.toml) ... done
Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827620 sha256=c0e52d00d327e4107d9d7487856d18889281e05f722af341747342271291b415
Stored in directory: /root/.cache/pip/wheels/b9/22/6d/3e7b32d98451b4cd9d12417052affbeeeea012955d437da1da
Successfully built gym
Installing collected packages: gym
Attempting uninstall: gym
Found existing installation: gym 0.25.2
Uninstalling gym-0.25.2:
Successfully uninstalled gym-0.25.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dopamine-rl 4.0.6 requires gym<=0.25.2, but you have gym 0.26.2 which is incompatible.
Successfully installed gym-0.26.2
Collecting pyglet==2.0.10
Downloading pyglet-2.0.10-py3-none-any.whl (858 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 858.3/858.3 kB 5.6 MB/s eta 0:00:00
Installing collected packages: pyglet
Successfully installed pyglet-2.0.10
Requirement already satisfied: pygame==2.5.2 in /usr/local/lib/python3.10/dist-packages (2.5.2)
Collecting PyQt5
Downloading PyQt5-5.15.10-cp37-abi3-manylinux_2_17_x86_64.whl (8.2 MB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.2/8.2 MB 23.6 MB/s eta 0:00:00
Collecting PyQt5-sip<13,>=12.13 (from PyQt5)
Downloading PyQt5_sip-12.13.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl (338 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 338.1/338.1 kB 29.8 MB/s eta 0:00:00
Collecting PyQt5-Qt5>=5.15.2 (from PyQt5)
Downloading PyQt5_Qt5-5.15.2-py3-none-manylinux2014_x86_64.whl (59.9 MB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.9/59.9 MB 8.4 MB/s eta 0:00:00
Installing collected packages: PyQt5-Qt5, PyQt5-sip, PyQt5
Successfully installed PyQt5-5.15.10 PyQt5-Qt5-5.15.2 PyQt5-sip-12.13.0
%% Cell type:markdown id: tags:
The REINFORCE algorithm (also known as Vanilla Policy Gradient) is a policy-gradient method that optimizes the policy directly by gradient ascent on the expected return. The following is the pseudocode of the REINFORCE algorithm:
%% Cell type:markdown id: tags:
**Set up the CartPole environment**

**Set up the agent as a simple neural network with:**
- One fully connected layer with 128 units and ReLU activation, followed by a dropout layer
- One fully connected layer followed by softmax activation

**Repeat 500 times:**
- Reset the environment
- Reset the buffer
- Repeat until the end of the episode:
  - Compute the action probabilities
  - Sample an action based on the probabilities and store its probability in the buffer
  - Step the environment with the action
  - Compute and store the return in the buffer, using gamma=0.99
- Normalize the returns
- Compute the policy loss as -sum(log(prob) * return)
- Update the policy using an Adam optimizer with a learning rate of 5e-3
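
For reference, the policy-gradient update that this pseudocode implements can be written as

$$\nabla_\theta J(\theta) \approx \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t, \qquad G_t = \sum_{k=t}^{T} \gamma^{\,k-t} r_k,$$

so minimizing the loss $L(\theta) = -\sum_t \log \pi_\theta(a_t \mid s_t)\, G_t$ by gradient descent performs gradient ascent on the expected return (with the normalized returns used in place of $G_t$).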
%% Cell type:code id: tags:
```
import gym, pygame, numpy as np, matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.distributions import Categorical

# Set up the CartPole environment
env = gym.make("CartPole-v1", render_mode="human")

# Set up the agent as a simple neural network
class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.FC1 = nn.Linear(env.observation_space.shape[0], 128)
        self.FC2 = nn.Linear(128, env.action_space.n)

    def forward(self, x):
        x = F.relu(self.FC1(x))
        x = F.dropout(x, training=self.training)
        x = self.FC2(x)
        # Return log-probabilities so the policy loss can use them directly
        return F.log_softmax(x, dim=1)

# Creation of the agent and of a single Adam optimizer reused across episodes
agent = Agent()
optimizer = optim.Adam(agent.parameters(), lr=5e-3)
rewards_tot = []

# Repeat 500 times
for i in range(500):
    # Reset the environment
    obs = env.reset()
    obs = obs[0] if isinstance(obs, tuple) else obs
    # Reset the buffer
    rewards, log_probs_list, terminated, truncated, step = [], [], False, False, 0
    # Repeat until the end of the episode
    while not (terminated or truncated) and step < 500:
        step += 1
        # Compute action probabilities
        obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
        log_probs = agent(obs_tensor)
        probs = torch.exp(log_probs)
        # Sample the action based on the probabilities and store its log-probability
        action = torch.multinomial(probs, 1).item()
        log_probs_list.append(log_probs[0, action])
        # Step the environment with the action (gym>=0.26 returns 5 values)
        obs, reward, terminated, truncated, _ = env.step(action)
        env.render()
        # Store the reward in the buffer
        rewards.append(reward)
    # Compute the discounted returns with gamma=0.99 and normalize them
    R = 0
    returns = []
    for r_i in rewards[::-1]:
        R = r_i + 0.99 * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32)
    returns = (returns - returns.mean()) / (returns.std() + 1e-9)
    rewards_tot.append(sum(rewards))
    # Compute the policy loss: -sum(log(prob) * return)
    loss = -torch.sum(torch.stack(log_probs_list) * returns)
    # Update the policy with the Adam optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Close the environment
env.close()

# Reward plot
plt.figure()
plt.plot(rewards_tot)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('REINFORCE rewards')
plt.show()
```
%% Output
%% Cell type:markdown id: tags:
We can see that although the average reward per episode increases, it does not reach the maximum reward value and still oscillates considerably after 500 episodes.
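
To make the trend easier to read, a moving-average overlay can be added to the reward plot. This is a minimal sketch, assuming the `rewards_tot` list from the training cell above is still in memory and using an arbitrary window of 20 episodes:
```
import numpy as np
import matplotlib.pyplot as plt

# Smooth the per-episode rewards with a simple moving average (hypothetical window size)
window = 20
smoothed = np.convolve(rewards_tot, np.ones(window) / window, mode="valid")

plt.figure()
plt.plot(rewards_tot, alpha=0.4, label="Episode reward")
plt.plot(range(window - 1, len(rewards_tot)), smoothed, label=f"{window}-episode moving average")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.legend()
plt.show()
```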
%% Cell type:markdown id: tags:
**Familiarization with a complete RL pipeline: Application to training a robotic arm**
In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.
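
Before the installation and training cells below, here is a minimal sketch of how these tools are commonly wired together with Stable-Baselines3: the agent is trained with `model.learn`, monitored through W&B's `WandbCallback`, and the saved zip file can then be shared. The project name and timestep budget are placeholders, not values prescribed by the exercise:
```
import gymnasium as gym
import wandb
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from wandb.integration.sb3 import WandbCallback

# Start a W&B run; sync_tensorboard forwards SB3's tensorboard metrics to the dashboard
run = wandb.init(project="cartpole-a2c", sync_tensorboard=True)

# Environment and agent
env = gym.make("CartPole-v1")
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")

# Train while logging to W&B, then evaluate and save the model for sharing
model.learn(total_timesteps=10_000, callback=WandbCallback(verbose=2))
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
model.save("a2c_cartpole")
run.finish()
```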
%% Cell type:code id: tags:
```
!pip install stable-baselines3
!pip install moviepy
!pip install huggingface-sb3==2.3.1
!pip install wandb
```
%% Output
Requirement already satisfied: stable-baselines3 in /usr/local/lib/python3.10/dist-packages (2.2.1)
Requirement already satisfied: gymnasium<0.30,>=0.28.1 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (0.29.1)
Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (1.25.2)
Requirement already satisfied: torch>=1.13 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (2.1.0+cu121)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (2.2.1)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (1.5.3)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from stable-baselines3) (3.7.1)
Requirement already satisfied: typing-extensions>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.10.0)
Requirement already satisfied: farama-notifications>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (3.13.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (3.2.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (3.1.3)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (2023.6.0)
Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3) (2.1.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (4.49.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (23.2)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (3.1.1)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->stable-baselines3) (2023.4)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->stable-baselines3) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13->stable-baselines3) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13->stable-baselines3) (1.3.0)
Requirement already satisfied: moviepy in /usr/local/lib/python3.10/dist-packages (1.0.3)
Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.10/dist-packages (from moviepy) (4.4.2)
Requirement already satisfied: tqdm<5.0,>=4.11.2 in /usr/local/lib/python3.10/dist-packages (from moviepy) (4.66.2)
Requirement already satisfied: requests<3.0,>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from moviepy) (2.31.0)
Requirement already satisfied: proglog<=1.0.0 in /usr/local/lib/python3.10/dist-packages (from moviepy) (0.1.10)
Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from moviepy) (1.25.2)
Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.10/dist-packages (from moviepy) (2.31.6)
Requirement already satisfied: imageio-ffmpeg>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from moviepy) (0.4.9)
Requirement already satisfied: pillow<10.1.0,>=8.3.2 in /usr/local/lib/python3.10/dist-packages (from imageio<3.0,>=2.5->moviepy) (9.4.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from imageio-ffmpeg>=0.2.0->moviepy) (67.7.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy) (2024.2.2)
Requirement already satisfied: huggingface-sb3==2.3.1 in /usr/local/lib/python3.10/dist-packages (2.3.1)
Requirement already satisfied: huggingface-hub~=0.8 in /usr/local/lib/python3.10/dist-packages (from huggingface-sb3==2.3.1) (0.20.3)
Requirement already satisfied: pyyaml~=6.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-sb3==2.3.1) (6.0.1)
Requirement already satisfied: wasabi in /usr/local/lib/python3.10/dist-packages (from huggingface-sb3==2.3.1) (1.1.2)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from huggingface-sb3==2.3.1) (1.25.2)
Requirement already satisfied: cloudpickle>=1.6 in /usr/local/lib/python3.10/dist-packages (from huggingface-sb3==2.3.1) (2.2.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.13.1)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2023.6.0)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.31.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.66.2)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.10.0)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (23.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2024.2.2)
Collecting wandb
Downloading wandb-0.16.4-py3-none-any.whl (2.2 MB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 24.7 MB/s eta 0:00:00
Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 195.4/195.4 kB 22.5 MB/s eta 0:00:00
Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.31.0)
Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)
Collecting sentry-sdk>=1.0.0 (from wandb)
Downloading sentry_sdk-1.40.6-py2.py3-none-any.whl (258 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 258.5/258.5 kB 17.5 MB/s eta 0:00:00
Collecting docker-pycreds>=0.4.0 (from wandb)
Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.1)
Collecting setproctitle (from wandb)
Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)
Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)
Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)
Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)
Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.7/62.7 kB 8.0 MB/s eta 0:00:00
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2024.2.2)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)
Downloading smmap-5.0.1-py3-none-any.whl (24 kB)
Installing collected packages: smmap, setproctitle, sentry-sdk, docker-pycreds, gitdb, GitPython, wandb
Successfully installed GitPython-3.1.42 docker-pycreds-0.4.0 gitdb-4.0.11 sentry-sdk-1.40.6 setproctitle-1.3.3 smmap-5.0.1 wandb-0.16.4
%% Cell type:code id: tags:
```
import wandb, gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from huggingface_hub import hf_api
from wandb.integration.sb3 import WandbCallback

# Set up the CartPole environment
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Choose the model
model = A2C("MlpPolicy", env, verbose=1)

# Print the initial mean reward
reward_before_moy, _ = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward before training: {reward_before_moy:.2f}")

# Train the model for 10000 timesteps
model.learn(total_timesteps=10_000)

# Print the mean reward after training
reward_after_moy, _ = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward after training: {reward_after_moy:.2f}")

# Upload and save the model
# Save the trained model
model_save_path = "model"
model.save(model_save_path)
model_path = "model.zip"

# Create the Hugging Face repository
repo_name = "BE-RL"
rep = hf_api.create_repo(token="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", repo_id=repo_name)

# Upload the model to the repository
repo_id = "hchauvin78/BE-RL"
hf_api.upload_file(token="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", repo_id=repo_id,
                   path_or_fileobj=model_path, path_in_repo=repo_name)

# Training with WandB
# Initialize WandB
wandb.init(project="cartpole-training", entity="hchauvin78", anonymous="allow")

# Configure WandB
config = wandb.config
config.learning_rate = 0.001
config.gamma = 0.99
config.n_steps = 500

# Run episodes with the policy and log their rewards to WandB
# (note: model.predict only runs the policy, it does not update it)
model = A2C('MlpPolicy', env, verbose=1, tensorboard_log="logs/")
episode_rewards = []
for i in range(25000):
    obs = env.reset()[0]
    reward_tot = 0
    terminated = truncated = False
    while not (terminated or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        reward_tot += reward
    episode_rewards.append(reward_tot)
    wandb.log({"Episode Reward": reward_tot, "Episode": i})
    # Log the mean reward every 10 episodes
    if i % 10 == 0:
        mean_reward = sum(episode_rewards[-10:]) / 10
        wandb.log({"Mean Reward": mean_reward})

# Log final metrics to WandB
wandb.log({"Mean Reward": mean_reward})

# Finish the WandB run
wandb.finish()
```
%% Output
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Mean reward before training: 355.20
------------------------------------
| rollout/ | |
| ep_len_mean | 31.5 |
| ep_rew_mean | 31.5 |
| time/ | |
| fps | 398 |
| iterations | 100 |
| time_elapsed | 1 |
| total_timesteps | 500 |
| train/ | |
| entropy_loss | -0.574 |
| explained_variance | 0.669 |
| learning_rate | 0.0007 |
| n_updates | 99 |
| policy_loss | 1.52 |
| value_loss | 6.1 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 30.2 |
| ep_rew_mean | 30.2 |
| time/ | |
| fps | 470 |
| iterations | 200 |
| time_elapsed | 2 |
| total_timesteps | 1000 |
| train/ | |
| entropy_loss | -0.517 |
| explained_variance | -0.893 |
| learning_rate | 0.0007 |
| n_updates | 199 |
| policy_loss | 2.03 |
| value_loss | 11.1 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 31.7 |
| ep_rew_mean | 31.7 |
| time/ | |
| fps | 504 |
| iterations | 300 |
| time_elapsed | 2 |
| total_timesteps | 1500 |
| train/ | |
| entropy_loss | -0.574 |
| explained_variance | 0.197 |
| learning_rate | 0.0007 |
| n_updates | 299 |
| policy_loss | 0.67 |
| value_loss | 3.55 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 32.3 |
| ep_rew_mean | 32.3 |
| time/ | |
| fps | 525 |
| iterations | 400 |
| time_elapsed | 3 |
| total_timesteps | 2000 |
| train/ | |
| entropy_loss | -0.517 |
| explained_variance | -0.0114 |
| learning_rate | 0.0007 |
| n_updates | 399 |
| policy_loss | 1.48 |
| value_loss | 6.36 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 32.8 |
| ep_rew_mean | 32.8 |
| time/ | |
| fps | 535 |
| iterations | 500 |
| time_elapsed | 4 |
| total_timesteps | 2500 |
| train/ | |
| entropy_loss | -0.494 |
| explained_variance | -0.0975 |
| learning_rate | 0.0007 |
| n_updates | 499 |
| policy_loss | 0.907 |
| value_loss | 5.78 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 34.7 |
| ep_rew_mean | 34.7 |
| time/ | |
| fps | 541 |
| iterations | 600 |
| time_elapsed | 5 |
| total_timesteps | 3000 |
| train/ | |
| entropy_loss | -0.447 |
| explained_variance | -0.00207 |
| learning_rate | 0.0007 |
| n_updates | 599 |
| policy_loss | 1.28 |
| value_loss | 4.87 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 35.5 |
| ep_rew_mean | 35.5 |
| time/ | |
| fps | 544 |
| iterations | 700 |
| time_elapsed | 6 |
| total_timesteps | 3500 |
| train/ | |
| entropy_loss | -0.302 |
| explained_variance | 0.0436 |
| learning_rate | 0.0007 |
| n_updates | 699 |
| policy_loss | 1.57 |
| value_loss | 4.52 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 36.8 |
| ep_rew_mean | 36.8 |
| time/ | |
| fps | 549 |
| iterations | 800 |
| time_elapsed | 7 |
| total_timesteps | 4000 |
| train/ | |
| entropy_loss | -0.256 |
| explained_variance | -0.00302 |
| learning_rate | 0.0007 |
| n_updates | 799 |
| policy_loss | 1.88 |
| value_loss | 4.09 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 38.6 |
| ep_rew_mean | 38.6 |
| time/ | |
| fps | 552 |
| iterations | 900 |
| time_elapsed | 8 |
| total_timesteps | 4500 |
| train/ | |
| entropy_loss | -0.229 |
| explained_variance | 0.037 |
| learning_rate | 0.0007 |
| n_updates | 899 |
| policy_loss | 1.99 |
| value_loss | 3.64 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 40.3 |
| ep_rew_mean | 40.3 |
| time/ | |
| fps | 556 |
| iterations | 1000 |
| time_elapsed | 8 |
| total_timesteps | 5000 |
| train/ | |
| entropy_loss | -0.42 |
| explained_variance | 4.21e-05 |
| learning_rate | 0.0007 |
| n_updates | 999 |
| policy_loss | -4.77 |
| value_loss | 1.13e+03 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 42.5 |
| ep_rew_mean | 42.5 |
| time/ | |
| fps | 558 |
| iterations | 1100 |
| time_elapsed | 9 |
| total_timesteps | 5500 |
| train/ | |
| entropy_loss | -0.473 |
| explained_variance | 4.71e-06 |
| learning_rate | 0.0007 |
| n_updates | 1099 |
| policy_loss | 0.298 |
| value_loss | 2.76 |
------------------------------------
-------------------------------------
| rollout/ | |
| ep_len_mean | 44.3 |
| ep_rew_mean | 44.3 |
| time/ | |
| fps | 561 |
| iterations | 1200 |
| time_elapsed | 10 |
| total_timesteps | 6000 |
| train/ | |
| entropy_loss | -0.384 |
| explained_variance | -0.000385 |
| learning_rate | 0.0007 |
| n_updates | 1199 |
| policy_loss | 0.415 |
| value_loss | 2.35 |
-------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 46.4 |
| ep_rew_mean | 46.4 |
| time/ | |
| fps | 559 |
| iterations | 1300 |
| time_elapsed | 11 |
| total_timesteps | 6500 |
| train/ | |
| entropy_loss | -0.342 |
| explained_variance | 0.000243 |
| learning_rate | 0.0007 |
| n_updates | 1299 |
| policy_loss | 0.578 |
| value_loss | 1.93 |
------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 52 |
| ep_rew_mean | 52 |
| time/ | |
| fps | 548 |
| iterations | 1400 |
| time_elapsed | 12 |
| total_timesteps | 7000 |
| train/ | |
| entropy_loss | -0.352 |
| explained_variance | 0.00346 |
| learning_rate | 0.0007 |
| n_updates | 1399 |
| policy_loss | 0.887 |
| value_loss | 1.55 |
------------------------------------
-------------------------------------
| rollout/ | |
| ep_len_mean | 55.2 |
| ep_rew_mean | 55.2 |
| time/ | |
| fps | 539 |
| iterations | 1500 |
| time_elapsed | 13 |
| total_timesteps | 7500 |
| train/ | |
| entropy_loss | -0.518 |
| explained_variance | -0.000267 |
| learning_rate | 0.0007 |
| n_updates | 1499 |
| policy_loss | 0.277 |
| value_loss | 1.21 |
-------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 60.7 |
| ep_rew_mean | 60.7 |
| time/ | |
| fps | 529 |
| iterations | 1600 |
| time_elapsed | 15 |
| total_timesteps | 8000 |
| train/ | |
| entropy_loss | -0.456 |
| explained_variance | 0.000455 |
| learning_rate | 0.0007 |
| n_updates | 1599 |
| policy_loss | 0.236 |
| value_loss | 0.918 |
------------------------------------
-------------------------------------
| rollout/ | |
| ep_len_mean | 64.2 |
| ep_rew_mean | 64.2 |
| time/ | |
| fps | 531 |
| iterations | 1700 |
| time_elapsed | 15 |
| total_timesteps | 8500 |
| train/ | |
| entropy_loss | -0.397 |
| explained_variance | -0.000141 |
| learning_rate | 0.0007 |
| n_updates | 1699 |
| policy_loss | 0.27 |
| value_loss | 0.668 |
-------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 67.1 |
| ep_rew_mean | 67.1 |
| time/ | |
| fps | 534 |
| iterations | 1800 |
| time_elapsed | 16 |
| total_timesteps | 9000 |
| train/ | |
| entropy_loss | -0.412 |
| explained_variance | 6.56e-07 |
| learning_rate | 0.0007 |
| n_updates | 1799 |
| policy_loss | 0.417 |
| value_loss | 0.45 |
------------------------------------
-------------------------------------
| rollout/ | |
| ep_len_mean | 75.1 |
| ep_rew_mean | 75.1 |
| time/ | |
| fps | 536 |
| iterations | 1900 |
| time_elapsed | 17 |
| total_timesteps | 9500 |
| train/ | |
| entropy_loss | -0.365 |
| explained_variance | -1.88e-05 |
| learning_rate | 0.0007 |
| n_updates | 1899 |
| policy_loss | 0.139 |
| value_loss | 0.274 |
-------------------------------------
------------------------------------
| rollout/ | |
| ep_len_mean | 79.9 |
| ep_rew_mean | 79.9 |
| time/ | |
| fps | 538 |
| iterations | 2000 |
| time_elapsed | 18 |
| total_timesteps | 10000 |
| train/ | |
| entropy_loss | -0.356 |
| explained_variance | 5.39e-05 |
| learning_rate | 0.0007 |
| n_updates | 1999 |
| policy_loss | 0.101 |
| value_loss | 0.148 |
------------------------------------
Mean reward after training: 256.70
/usr/local/lib/python3.10/dist-packages/notebook/utils.py:280: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
return LooseVersion(v) >= LooseVersion(check)
wandb: (1) Private W&B dashboard, no account required
wandb: (2) Use an existing W&B account
wandb: Enter your choice: 1
wandb: You chose 'Private W&B dashboard, no account required'
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Problem at: <ipython-input-28-b66b1c5d0647> 35 <cell line: 35>
---------------------------------------------------------------------------
CommError Traceback (most recent call last)
<ipython-input-28-b66b1c5d0647> in <cell line: 35>()
33 # Training with WandB
34 # Initializing WandB
---> 35 wandb.init(project="cartpole-training", entity="hchauvin78", anonymous="allow")
36
37 #Configuring WandB
/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py in init(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)
1193 if logger is not None:
1194 logger.exception(str(e))
-> 1195 raise e
1196 except KeyboardInterrupt as e:
1197 assert logger
/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py in init(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)
1174 except_exit = wi.settings._except_exit
1175 try:
-> 1176 run = wi.init()
1177 except_exit = wi.settings._except_exit
1178 except (KeyboardInterrupt, Exception) as e:
/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py in init(self)
783 backend.cleanup()
784 self.teardown()
--> 785 raise error
786
787 assert run_result is not None # for mypy
CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 404: Not Found)
%% Cell type:markdown id: tags:
I am struggling to access my Hugging Face account even though I generated tokens... as I cannot work out what is happening, I will stop here :'(
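
For completeness, here is one common pattern for the upload step, using `HfApi` from `huggingface_hub`. The repository id and token below are placeholders; the namespace in `repo_id` has to match the account the write token belongs to:
```
from huggingface_hub import HfApi

# Minimal sketch of uploading a saved SB3 model to the Hugging Face Hub
# ("your-username/BE-RL" and "hf_xxx" are placeholders, not real values).
api = HfApi()
api.create_repo(repo_id="your-username/BE-RL", token="hf_xxx", exist_ok=True)
api.upload_file(
    path_or_fileobj="model.zip",   # the zip produced by model.save("model")
    path_in_repo="model.zip",      # destination filename inside the repo
    repo_id="your-username/BE-RL",
    token="hf_xxx",
)
```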