Commit 0a43d63f authored by hamza masmoudi

Initial commit with all necessary files

# Ignore virtual environment files
venv/
# Ignore W&B log files
wandb/
# Ignore generated files
*.zip
*.pth
# Ignore PyCharm configuration files
.idea/
# Ignore VSCode configuration files
.vscode/
# Ignore Python cache files
__pycache__/
*.pyc
# Ignore temporary and build files
*.log
*.tmp
*.bak
# Ignore user configuration files
.DS_Store
README.md 0 → 100644
# TD 1: Hands-On Reinforcement Learning
MSO 3.4 Apprentissage Automatique
In this hands-on project, we will first implement a simple RL algorithm and apply it to solve the CartPole-v1 environment. Once we become familiar with the basic workflow, we will learn to use various tools for machine learning model training, monitoring, and sharing, by applying these tools to train a robotic arm.
## To be handed in
This work must be done individually. The expected output is a repository named `hands-on-rl` on https://gitlab.ec-lyon.fr.
We assume that `git` is installed and that you are familiar with the basic `git` commands. (Optionally, you can use GitHub Desktop.)
We also assume that you have access to the [ECL GitLab](https://gitlab.ec-lyon.fr/). If necessary, please consult [this tutorial](https://gitlab.ec-lyon.fr/edelland/inf_tc2/-/blob/main/Tutoriel_gitlab/tutoriel_gitlab.md).
Your repository must contain a `README.md` file that **briefly** explains the successive steps of the project. The repository must be private, so you need to add your teacher as a "developer" member.
Throughout the assignment, you will find a 🛠 symbol indicating that a specific deliverable is expected.
The last commit is due before 11:59 pm on March 17, 2025. Subsequent commits will not be considered.
> ⚠️ **Warning**
> Ensure that you only commit the files that are requested. For example, your directory should not contain the generated `.zip` files, nor the `runs` folder... At the end, your repository must contain one `README.md`, three Python scripts, and optionally image files for the plots.
## Before you start
Make sure you know the basics of Reinforcement Learning. In case of need, you can refer to the [introduction of the Hugging Face RL course](https://huggingface.co/blog/deep-rl-intro).
## Introduction to Gym
[Gym](https://gymnasium.farama.org/) is a toolkit for developing and comparing reinforcement learning algorithms. It offers various environments, including classic control and toy text scenarios, for testing RL agents.
### Installation
We recommend using a Python virtual environment to install the required modules: https://docs.python.org/3/library/venv.html, or https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html if you are using conda.
First, install PyTorch: https://pytorch.org/get-started/locally.
Then install the following modules:
```sh
pip install gymnasium
```
```sh
pip install "gymnasium[classic-control]"
```
### Usage
Here is an example of how to use Gym to interact with the `CartPole-v1` environment ([documentation](https://gymnasium.farama.org/environments/classic_control/cart_pole/)):
```python
import gymnasium as gym

# Create the environment
env = gym.make("CartPole-v1", render_mode="human")

# Reset the environment and get the initial observation
observation, info = env.reset()

for _ in range(100):
    # Select a random action from the action space
    action = env.action_space.sample()

    # Apply the action to the environment.
    # Returns the next observation, the reward, the terminated and truncated
    # flags (indicating whether the episode has ended), and an info dictionary
    observation, reward, terminated, truncated, info = env.step(action)

    # Render the environment to visualize the agent's behavior
    env.render()

    if terminated:
        # The episode ended before the maximum number of steps
        break

env.close()
```
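Before writing an agent, it also helps to inspect the spaces the environment exposes, since they determine the input and output sizes of any policy network you will build later. A quick check (the printed values in the comments are those of `CartPole-v1`):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")

# The observation is a 4-dimensional vector: cart position, cart velocity,
# pole angle and pole angular velocity
print(env.observation_space)  # Box(4,)
# Two discrete actions: 0 = push the cart to the left, 1 = push it to the right
print(env.action_space)       # Discrete(2)

env.close()
```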
## REINFORCE
The REINFORCE algorithm (also known as Vanilla Policy Gradient) is a policy gradient method that optimizes the policy directly, using gradient ascent on the expected return. The following is the pseudocode of the REINFORCE algorithm:
```txt
Setup the CartPole environment
Setup the agent as a simple neural network with:
    - One fully connected layer with 128 units and ReLU activation followed by a dropout layer
    - One fully connected layer followed by softmax activation
Repeat 500 times:
    Reset the environment
    Reset the buffer
    Repeat until the end of the episode:
        Compute action probabilities
        Sample the action based on the probabilities and store its probability in the buffer
        Step the environment with the action
        Compute and store in the buffer the return using gamma=0.99
    Normalize the return
    Compute the policy loss as -sum(log(prob) * return)
    Update the policy using an Adam optimizer and a learning rate of 5e-3
Save the model weights
```
To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/policy-gradient).
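Two of the pseudocode steps, computing the discounted return and normalizing it, are easy to get subtly wrong. Here is a minimal, standalone sketch of one way to compute them (gamma comes from the pseudocode; the `1e-9` epsilon is just a common choice to avoid division by zero):

```python
import torch

def discounted_returns(rewards, gamma=0.99):
    """Compute the return G_t = r_t + gamma * G_{t+1} for each step of an episode."""
    returns = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    # Normalize to zero mean and unit variance to stabilize the gradient
    return (returns - returns.mean()) / (returns.std() + 1e-9)

# Example: a 3-step episode with reward 1 at every step;
# earlier steps get larger returns before normalization
print(discounted_returns([1.0, 1.0, 1.0]))
```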
> 🛠 **To be handed in**
> Use PyTorch to implement REINFORCE and solve the CartPole environment. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward across episodes in the `README.md`. Also share a file `reinforce_cartpole.pth` containing the learned weights. For saving and loading PyTorch models, check [this tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference).
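For `reinforce_cartpole.pth`, the pattern from the linked tutorial amounts to saving and reloading the model's `state_dict`. A minimal sketch, using a stand-in network with CartPole's dimensions (your own policy class goes in its place):

```python
import torch
import torch.nn as nn

# Stand-in for the policy network (4 observations in, 2 action probabilities out)
policy_net = nn.Sequential(
    nn.Linear(4, 128), nn.ReLU(), nn.Dropout(0.6),
    nn.Linear(128, 2), nn.Softmax(dim=-1),
)

# After training: save only the learned parameters
torch.save(policy_net.state_dict(), "reinforce_cartpole.pth")

# Later, for evaluation: rebuild the same architecture and load the parameters
policy_net.load_state_dict(torch.load("reinforce_cartpole.pth"))
policy_net.eval()  # disables dropout for deterministic evaluation
```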
## Model Evaluation
Now that you have trained your model, it is time to evaluate its performance. Run it with rendering for a few trials and see if the policy is capable of completing the task.
> 🛠 **To be handed in**
> Implement a script that loads your saved model and uses it to solve the CartPole environment. Run 100 evaluations and share the final success rate across all evaluations in the `README.md`. Share the code in `evaluate_reinforce_cartpole.py`.
## Familiarization with a complete RL pipeline: Application to training a robotic arm
In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.
### Get familiar with Stable-Baselines3
Stable-Baselines3 (SB3) is a high-level RL library that provides various algorithms and integrated tools to easily train and test reinforcement learning models.
#### Installation
```sh
pip install stable-baselines3
pip install "stable-baselines3[extra]"
pip install moviepy
```
#### Usage
Use the [Stable-Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/) to implement the code to solve the CartPole environment with the Advantage Actor-Critic (A2C) algorithm.
> 🛠 **To be handed in**
> Store the code in `a2c_sb3_cartpole.py`. Unless otherwise stated, you'll work upon this file for the next sections.
### Get familiar with Hugging Face Hub
Hugging Face Hub is a platform for easy sharing and versioning of trained machine learning models. With Hugging Face Hub, you can quickly and easily share your models with others and make them usable through the API. For example, see the trained A2C agent for CartPole: https://huggingface.co/sb3/a2c-CartPole-v1. Hugging Face Hub provides an API to download and upload SB3 models.
#### Installation of `huggingface_sb3`
```sh
pip install huggingface-sb3
```
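To check both the installation and the download direction of the API, you can pull the public pretrained agent mentioned above with `load_from_hub` (a sketch; the file name inside the `sb3/a2c-CartPole-v1` repository is assumed from the usual SB3 naming scheme):

```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C

# Download a pretrained A2C CartPole checkpoint from the Hub and load it with SB3
checkpoint = load_from_hub(
    repo_id="sb3/a2c-CartPole-v1",   # public example repository
    filename="a2c-CartPole-v1.zip",  # assumed file name inside that repository
)
model = A2C.load(checkpoint)
print(model.policy)
```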
#### Upload the model on the Hub
Follow the [Hugging Face Hub documentation](https://huggingface.co/docs/hub/stable-baselines3) to upload the previously learned model to the Hub.
> 🛠 **To be handed in**
> Link the trained model in the `README.md` file.
> 📝 **Note**
> [RL-Zoo3](https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html) provides more advanced features to save hyperparameters, generate renderings and metrics. Feel free to try them.
### Get familiar with Weights & Biases
Weights & Biases (W&B) is a tool for machine learning experiment management. With W&B, you can track and compare your experiments, and visualize your model's training and performance.
#### Installation
You'll need to install both `wandb` and `tensorboard`.
```shell
pip install wandb
pip install tensorboard
```
Use the documentation of [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) and [Weights & Biases](https://docs.wandb.ai/guides/integrations/stable-baselines-3) to track the CartPole training. Make the run public.
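The integration documented above essentially reduces to initializing a run with `sync_tensorboard=True` and passing a `WandbCallback` to `learn` (a minimal sketch; the project name is a placeholder, and the `runs/` folder it creates is exactly the kind of generated output that should stay out of your repository):

```python
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

# sync_tensorboard forwards the scalars SB3 writes to TensorBoard on to W&B
run = wandb.init(project="sb3-cartpole", sync_tensorboard=True)

env = make_vec_env("CartPole-v1", n_envs=1)
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(total_timesteps=10_000, callback=WandbCallback())

run.finish()
env.close()
```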
> 🛠 **To be handed in**
> Share the link of the wandb run in the `README.md` file.
> ⚠️ **Warning**
> Make sure to make the run public! If it is not possible (due to the restrictions on your account), you can create a WandB [report](https://docs.wandb.ai/guides/reports/create-a-report/), add all relevant graphs and any textual descriptions or explanations you find pertinent, then download a PDF file (landscape format) and upload it along with the code to GitLab. Make sure to arrange the plots in a way that makes them understandable in the PDF (e.g., one graph per row, correct axes, etc.). Specify which report corresponds to which experiment.
### Full workflow with panda-gym
[Panda-gym](https://github.com/qgallouedec/panda-gym) is a collection of environments for robotic simulation and control. It provides a range of challenges for training robotic agents in a simulated environment. In this section, you will get familiar with one of the environments provided by panda-gym, `PandaReachJointsDense-v3`. The objective is to learn how to reach any point in 3D space by directly controlling the robot's joints.
#### Installation
```shell
pip install panda-gym==3.0.7
```
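Once installed, a quick sanity check shows what the environment looks like; its observations are dictionaries (goal-conditioned API), which is why `MultiInputPolicy` is used when training with SB3 (a sketch, assuming panda-gym 3.x as pinned above):

```python
import gymnasium as gym
import panda_gym  # importing the package registers the Panda environments

env = gym.make("PandaReachJointsDense-v3")
observation, info = env.reset()

# Goal-conditioned environments expose a dictionary observation,
# which is why the SB3 policy used later must be a MultiInputPolicy
print(env.observation_space)  # Dict with 'observation', 'achieved_goal', 'desired_goal'
print(env.action_space)       # Box, one command per controlled joint

env.close()
```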
#### Train, track, and share
Use the Stable-Baselines3 package to train an A2C model on the `PandaReachJointsDense-v3` environment. 500k timesteps should be enough. Track the training with Weights & Biases. Once the training is over, upload the trained model to the Hub.
> 🛠 **To be handed in**
> Share all the code in `a2c_sb3_panda_reach.py`. Share the link of the wandb run and the trained model in the `README.md` file.
## Contribute
This tutorial may contain errors, inaccuracies, typos or areas for improvement. Feel free to contribute to its improvement by opening an issue.
## Author
Quentin Gallouédec
Updates by Bruno Machado, Léo Schneider, Emmanuel Dellandréa
## License
MIT
### Evaluation Results
Success Rate: 100.00%
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
# Create and wrap the environment
env_id = "CartPole-v1"
env = make_vec_env(env_id, n_envs=1)
# Initialize the A2C agent
model = A2C('MlpPolicy', env, verbose=1)
# Train the agent
model.learn(total_timesteps=10000)
# Save the trained model
model.save("a2c_sb3_cartpole")
# Evaluate the trained agent
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
# Close the environment
env.close()
import wandb
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
from gymnasium.envs.registration import register
import panda_gym
# Manually register the environment
register(
    id='PandaReach-v3',
    entry_point='panda_gym.envs:PandaReachEnv',
    max_episode_steps=50,
)
# Initialize the W&B run
wandb.init(
    project="panda-reach-experiments",
    entity="hamza-masmoudi-central-lyon",
    name="a2c-panda-reach-run-1",
    config={
        "algorithm": "A2C",
        "environment": "PandaReach-v3",
        "learning_rate": 0.0007,
        "n_envs": 1,
        "total_timesteps": 500000
    }
)
# Create a vectorized environment
env = make_vec_env("PandaReach-v3", n_envs=1)
# Initialize the model with MultiInputPolicy
model = A2C("MultiInputPolicy", env, verbose=1)
# Train the model
model.learn(total_timesteps=500000)
# Save the model
model.save("a2c_sb3_panda_reach.zip")
wandb.save("a2c_sb3_panda_reach.zip")
# Close the environment
env.close()
# Finish the W&B run
wandb.finish()
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
# Define the neural network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.fc2 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
# Function to load the model
def load_model(model, filename):
    model.load_state_dict(torch.load(filename))
    model.eval()
# Parameters
env = gym.make("CartPole-v1", render_mode="human")
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
num_evaluations = 100
# Initialize the neural network
policy_net = PolicyNetwork(input_dim, output_dim)
load_model(policy_net, "reinforce_cartpole.pth")
# Evaluate the model
success_count = 0
for evaluation in range(num_evaluations):
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    done = False
    total_reward = 0

    while not done:
        # Select an action
        with torch.no_grad():
            action_probs = policy_net(state)
            action = torch.argmax(action_probs, dim=1).item()

        # Apply the action and get the new state and reward
        next_state, reward, terminated, truncated, info = env.step(action)
        next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
        total_reward += reward
        state = next_state
        done = terminated or truncated

    # Check if the episode was successful
    if total_reward >= 195:  # Consider an episode successful if the reward is 195 or more
        success_count += 1

    print(f"Evaluation {evaluation + 1}/{num_evaluations}, Total Reward: {total_reward}")
env.close()
# Calculate success rate
success_rate = success_count / num_evaluations
print(f"Success Rate: {success_rate * 100:.2f}%")
# Save the success rate to README.md
with open("README.md", "a") as f:
    f.write(f"\n### Evaluation Results\nSuccess Rate: {success_rate * 100:.2f}%\n")
from huggingface_sb3 import push_to_hub
# Push the model to Hugging Face Hub
push_to_hub(
    repo_id="whoshamza/panda-reach-model",
    filename="a2c_sb3_panda_reach.zip",
    commit_message="Upload A2C Panda Reach model"
)
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Define the neural network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.fc2 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
# Function to save the model
def save_model(model, filename):
    torch.save(model.state_dict(), filename)
# Parameters
env = gym.make("CartPole-v1")
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
learning_rate = 0.001
gamma = 0.99
num_episodes = 500
# Initialize the neural network and optimizer
policy_net = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
# List to store total rewards per episode
total_rewards = []
# Training the model
for episode in range(num_episodes):
    state, _ = env.reset()  # Extract the observation from the tuple
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    done = False
    rewards = []
    log_probs = []

    while not done:
        # Select an action
        action_probs = policy_net(state)
        action_dist = torch.distributions.Categorical(action_probs)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        log_probs.append(log_prob)

        # Apply the action and get the new state and reward
        next_state, reward, terminated, truncated, info = env.step(action.item())
        next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
        rewards.append(reward)
        state = next_state
        done = terminated or truncated

    # Compute returns
    returns = []
    R = 0
    for reward in reversed(rewards):
        R = reward + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)

    # Normalize returns
    returns = (returns - returns.mean()) / (returns.std() + 1e-9)

    # Compute policy loss
    policy_loss = []
    for log_prob, R in zip(log_probs, returns):
        policy_loss.append(-log_prob * R)
    policy_loss = torch.cat(policy_loss).sum()

    # Perform an optimization step
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # Store total reward for this episode
    total_rewards.append(sum(rewards))

    if (episode + 1) % 10 == 0:
        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_rewards[-1]}")
# Save the model weights
save_model(policy_net, "reinforce_cartpole.pth")
# Plot total reward per episode
plt.plot(total_rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Total Reward per Episode")
plt.savefig("reward_plot.png")
plt.show()
reward_plot.png (binary image, 51.5 KiB)
rigin 0 → 100644
commit 14c98a72cf6397e384cf9f0d1ef4f6461855d691 (HEAD -> main, origin/main, origin/HEAD)
Author: hamza masmoudi <hamza.masmoudi.ai@gmail.com>
Date: Mon Mar 17 22:03:43 2025 +0100
Ajout du fichier .gitignore et suppression des fichiers indésirables
commit 2478d377447a3b315cb6656c8f5383b8bcfc4366
Author: hamza masmoudi <hamza.masmoudi.ai@gmail.com>
Date: Mon Mar 17 22:01:01 2025 +0100
Initial commit with all project files
commit 099430d878d40986c4ac53542b400de5a994d373
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 17 13:44:28 2025 +0100
Update README.md
commit 46ade8414c9c2fae9320c131ad54027eedefbd48
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 17 13:36:27 2025 +0100
Update README.md
commit 33acdc2a1c56c9852d9bac336cc3d42f3d7566ad
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 17 06:41:12 2025 +0100
Update README.md
commit 245e2c82f6d87e2541a2ae31d2e63f9f20f28837
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 17 06:23:21 2025 +0100
Update README.md
commit 180efac47b22dc074ec5dbded070404003ae5555
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 10 13:51:32 2025 +0100
Update README.md
commit 74c92be49d92c03b3b5210daf34b48c6b22956d9
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Mon Feb 10 06:14:23 2025 +0100
Update README.md
commit ef969c9dea3fd5a4be9909d81b3f9b49248cb3e0
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue Feb 6 12:29:36 2024 +0100
Update README.md
commit 840aba241c137f65008c2ad59d87d0feb117c6a5
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue Feb 6 08:00:36 2024 +0100
Update README.md
commit 7902dbc19d08d5668b46183a3c5fe1c9a0b05c75
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue Feb 6 07:48:28 2024 +0100
Update README.md
commit f70d479c12a1033737fe5e876a27e3289dbe87d6
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue Feb 6 07:47:28 2024 +0100
Update README.md
commit 2e754025d1615e578b3a1880ce5df9b4fd166ebe
Author: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue Feb 6 07:42:15 2024 +0100
Update README.md
commit 5db001cf356108c396244a9572852cab988b2feb
Author: Dellandrea Emmanuel <emmanuel.dellandrea@ec-lyon.fr>
Date: Thu Dec 14 11:44:49 2023 +0000
Update README.md
commit 015a8684ea83b806bec130f6542dc875593a54b2
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Mon Feb 6 12:15:42 2023 +0000
Update README.md
commit 3d3aca7e11ee8124132cb001a8ae3c636c5c0308
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Mon Feb 6 12:04:06 2023 +0000
Complete drop of extra
commit c78ba997ee69d25f269f4d7b0c43af7e3f84a066
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Mon Feb 6 12:03:22 2023 +0000
drop extra in sb3 install
commit b7f45be80e06d14c5b0b33021a5b0ed55fc7da9e
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Mon Feb 6 11:39:16 2023 +0000
Install piglet with gym
commit f0cd812ad3ee51e6a92053d57243bd50f8704a05
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Sun Feb 5 20:00:21 2023 +0000
link sb3 and wandb doc
commit e19a3a8707c4120ea7225a892674193505945f72
Merge: 4768fd97 b28160bc
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Sun Feb 5 19:55:01 2023 +0000
Merge branch 'qgalloue-main-patch-75980' into 'main'
Add link to hf cours on reinforce
See merge request qgalloue/hands-on-rl!1
commit b28160bcdd3a949bd1c582ab6b012fc0a337f168
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Sun Feb 5 19:54:51 2023 +0000
Add link to hf cours on reinforce
commit 4768fd978be9b04de1558627b17bdf0214e54a44
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Sun Feb 5 19:47:10 2023 +0000
Add link to HF course
commit 37097153394ef42d6e8d5ae2faa949a291a56981
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Fri Feb 3 14:29:23 2023 +0100
Minor rendering improvements
commit d914e5f2696879a45cf8486be01facd77f67d019
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Fri Feb 3 13:07:00 2023 +0100
total reward and A2C for cartpole
commit 936f260f5af1d7fe9296faa257738a7d0c52530d
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Fri Feb 3 11:59:16 2023 +0100
only store the chosen action prob in reinforce
commit 1c50ff9b8302da89b3e547c192e820451a4fe6cb
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Fri Feb 3 11:36:56 2023 +0100
Fix CartPole capitalization
commit 63e0ece39c6325ad624f14e0098275d4277f699c
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Wed Feb 1 19:38:43 2023 +0100
add missing pip install
commit d1d91428d451d7eb28a31863d222483407f2cd40
Author: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
Date: Wed Feb 1 19:35:45 2023 +0100
Initial version
commit fdee686e994cd80e9652d577eb592ccedc1d027b
Author: Gallouedec Quentin <quentin.gallouedec@ec-lyon.fr>
Date: Wed Feb 1 15:22:32 2023 +0000
Initial commit
from huggingface_sb3 import push_to_hub
from stable_baselines3 import A2C
from huggingface_hub import login
login()
# Load the trained model
model = A2C.load("a2c_sb3_cartpole")
# Save the model to a file
filename = "a2c_sb3_cartpole.zip"
model.save(filename)
# Define the repository ID
repo_id = "whoshamza/a2c-cartpole" # Replace with your Hugging Face username
# Push the saved model file to Hugging Face Hub
push_to_hub(repo_id=repo_id, filename=filename, commit_message="Upload A2C CartPole model")
import gymnasium as gym
try:
    env = gym.make("PandaReachJointsDense-v3")
    print("Environment created successfully!")
except Exception as e:
    print(f"Error creating environment: {e}")
import wandb
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
# Initialize W&B run
wandb.init(
    project="cartpole-experiments",
    entity="hamza-masmoudi-central-lyon",
    name="a2c-cartpole-run-1",
    config={
        "algorithm": "A2C",
        "environment": "CartPole-v1",
        "learning_rate": 0.0007,
        "n_envs": 1,
        "total_timesteps": 10000
    }
)
# Create vectorized environment
env = make_vec_env("CartPole-v1", n_envs=1)
# Custom callback to log episode rewards
class EpisodeRewardCallback(BaseCallback):
    def __init__(self, verbose=0):
        super(EpisodeRewardCallback, self).__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Access the episode reward from the environment's info
        if len(self.locals['infos']) > 0 and 'episode' in self.locals['infos'][0]:
            episode_info = self.locals['infos'][0].get('episode', {})
            episode_reward = episode_info.get('r')
            if episode_reward is not None:
                wandb.log({"episode_reward": episode_reward})
                self.episode_rewards.append(episode_reward)
        return True
# Initialize model
model = A2C("MlpPolicy", env, verbose=1)
# Train model with custom callback
model.learn(
    total_timesteps=10000,
    callback=EpisodeRewardCallback()
)
# Save the model
model.save("a2c_sb3_cartpole.zip")
wandb.save("a2c_sb3_cartpole.zip")
# Close environment
env.close()
# Finish the W&B run
wandb.finish()