Skip to content
Snippets Groups Projects
Commit d1fd9406 authored by Ghelfi Manon's avatar Ghelfi Manon
Browse files

Update README.md

parent 970816e7
Branches
No related tags found
No related merge requests found
...@@ -12,7 +12,6 @@ Le graphique de l'évolution des recompenses totales aux cours des épisodes est ...@@ -12,7 +12,6 @@ Le graphique de l'évolution des recompenses totales aux cours des épisodes est
Le fichier a2c_sb3_cartpole.py comporte un model pour resoudre le problème du CartPole en utilisant un algorithme Advantage Actor-Critic (A2C) grace à la bilbiothèque Stable-Baselines3. Le fichier a2c_sb3_cartpole.py comporte un model pour resoudre le problème du CartPole en utilisant un algorithme Advantage Actor-Critic (A2C) grace à la bilbiothèque Stable-Baselines3.
## Hugging Face Hub ## Hugging Face Hub
**(TODO: verifier pour avoir plus de trucs)**
https://huggingface.co/manonghelfi/a2c_cartpole/tree/main https://huggingface.co/manonghelfi/a2c_cartpole/tree/main
J'ai téléchargé mon model sur huggingface avec les commandes python suivantes : J'ai téléchargé mon model sur huggingface avec les commandes python suivantes :
...@@ -31,22 +30,31 @@ push_to_hub( ...@@ -31,22 +30,31 @@ push_to_hub(
Aprés mettre identifié grace à la commande : `huggingface-cli login` Aprés mettre identifié grace à la commande : `huggingface-cli login`
## Weights & Biases ## Weights & Biases
**(TODO: trouver comment mettre des données utiles)** Le run du model est présent ici : https://wandb.ai/ghelfi/cartpole-training/runs/06exlpbm
Réalisé grace au code ci dessous :
``` ```
import wandb import wandb
wandb.init(project='cartpole-training')
import gym import gym
from stable_baselines3 import A2C from stable_baselines3 import A2C
import numpy as np
env = gym.make("CartPole-v1")
wandb.init(project='a2c_CartPole')
env = gym.make('CartPole-v1')
model = A2C("MlpPolicy", env, verbose=1) model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=500000) model.learn(total_timesteps=10000)
observations = [env.reset() for _ in range(100)] rewards = []
actions, _states = model.predict(observations) obs = env.reset()
accuracy = sum([a == env.action_space.label[i] for i, a in enumerate(actions)]) / len(actions) while True:
wandb.log({"model": model, 'accuracy':accuracy}) action, _states = model.predict(obs)
model.save("a2c_CartPole") obs, reward, done, info = env.step(action)
rewards.append(reward)
if done:
break
print("Mean Reward: ", np.mean(rewards))
wandb.log({'reward_mean': np.mean(rewards)})
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment