Commit fd4364e4 authored by number_cruncher
Commit message: image
Parent: fa4e312e
@@ -23,6 +23,8 @@ Now that you have trained your model, it is time to evaluate its performance. Ru…
From the OpenAI Gym wiki we know that the environment counts as solved when the average reward is greater than or equal to 195 over 100 consecutive trials.
With the evaluation script I used, the success rate is 1.0 when we allow the maximum number of steps the environment offers.
+![REINFORCE CartPole](reinforce_cartpole_dr_0.5.png)
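As a reference for this success-rate claim, a minimal evaluation loop could look like the sketch below. It assumes the trained REINFORCE policy is a PyTorch module that maps an observation to action probabilities and was saved as `reinforce_cartpole.pth`; the `Policy` class and the file name are assumptions, not taken from this commit.

```python
import gymnasium as gym
import torch

from reinforce_cartpole import Policy  # hypothetical module/class name

env = gym.make("CartPole-v1")
policy = Policy()
policy.load_state_dict(torch.load("reinforce_cartpole.pth"))
policy.eval()

num_episodes = 100
threshold = 195  # CartPole "solved" criterion from the OpenAI Gym wiki
returns = []

for _ in range(num_episodes):
    obs, info = env.reset()
    done, episode_return = False, 0.0
    while not done:
        with torch.no_grad():
            probs = policy(torch.as_tensor(obs, dtype=torch.float32))
        action = int(torch.argmax(probs))  # greedy action at evaluation time
        obs, reward, terminated, truncated, info = env.step(action)
        episode_return += reward
        done = terminated or truncated
    returns.append(episode_return)

success_rate = sum(r >= threshold for r in returns) / num_episodes
print(f"average return: {sum(returns) / num_episodes:.1f}, success rate: {success_rate:.2f}")
```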
## Familiarization with a complete RL pipeline: Application to training a robotic arm

Stable-Baselines3 (SB3) is a high-level RL library that provides various algorithms and integrated tools to easily train and test reinforcement learning models.
@@ -36,7 +38,8 @@ Stable-Baselines3 (SB3) is a high-level RL library that provides various algorit…
🛠 Share the link of the wandb run in the `README.md` file.

wandb: https://wandb.ai/lennartecl-centrale-lyon/sb3?nw=nwuserlennartecl
-hugging: https://huggingface.co/lennartoe/Cartpole-v1/tree/main
+huggingface: https://huggingface.co/lennartoe/Cartpole-v1/tree/main
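To sanity-check the shared artifact, the uploaded model can be pulled back from the Hub and evaluated locally, e.g. with `huggingface_sb3.load_from_hub` and SB3's `evaluate_policy`. A minimal sketch; the checkpoint filename inside the repo is an assumption.

```python
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

# The filename of the checkpoint inside the repo is an assumption; adjust to the actual artifact.
checkpoint = load_from_hub(repo_id="lennartoe/Cartpole-v1", filename="a2c-CartPole-v1.zip")
model = A2C.load(checkpoint)

env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")
```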
### Full workflow with panda-gym
@@ -46,5 +49,6 @@ hugging: https://huggingface.co/lennartoe/Cartpole-v1/tree/main
> Share all the code in `a2c_sb3_panda_reach.py`. Share the link of the wandb run and the trained model in the `README.md` file.

wandb: https://wandb.ai/lennartecl-centrale-lyon/pandasgym_sb3?nw=nwuserlennartecl
-hugging: https://huggingface.co/lennartoe/PandaReachJointsDense-v3/tree/main
+huggingface: https://huggingface.co/lennartoe/PandaReachJointsDense-v3/tree/main
@@ -20,7 +20,7 @@ run = wandb.init(
    save_code=True,
)

-env = gym.make("CartPole-v1", render_mode="rgb_array")
+env = gym.make("CartPole-v1")

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
#model = A2C("MlpPolicy", env, )
@@ -31,7 +31,6 @@ obs = vec_env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
-    vec_env.render("human")

run.finish()
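The hunks above only show the parts of the CartPole A2C script that changed; the training call and the `vec_env` used in the prediction loop sit outside them. Following the wandb SB3 integration documentation, the full flow typically looks like the sketch below; the project name, timestep budget, and callback parameters are assumptions, not values from this commit.

```python
import gymnasium as gym
import wandb
from stable_baselines3 import A2C
from wandb.integration.sb3 import WandbCallback

# Assumed project/config values; the actual ones are not visible in this diff.
run = wandb.init(project="sb3", sync_tensorboard=True, save_code=True)

env = gym.make("CartPole-v1")
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")

# Stream SB3 training metrics to the wandb run; 25_000 steps is an assumption.
model.learn(
    total_timesteps=25_000,
    callback=WandbCallback(model_save_path=f"models/{run.id}", verbose=2),
)

# model.get_env() returns the VecEnv SB3 trained on, i.e. the vec_env
# stepped through in the prediction loop shown in the diff.
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

run.finish()
```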
@@ -10,7 +10,7 @@ from huggingface_sb3 import package_to_hub
# from documentation of wandb
config = {
    "policy_type": "MultiInputPolicy",
-    "total_timesteps": 50000,
+    "total_timesteps": 500000,
    "env_name": "PandaReachJointsDense-v3",
}
run = wandb.init(
@@ -21,7 +21,7 @@ run = wandb.init(
    save_code=True,
)

-env = gym.make("PandaReachJointsDense-v3", render_mode="rgb_array")
+env = gym.make("PandaReachJointsDense-v3")

model = A2C("MultiInputPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
#model = A2C("MlpPolicy", env, )
@@ -32,10 +32,6 @@ obs = vec_env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
-    vec_env.render("human")
-    # VecEnv resets automatically
-    # if done:
-    #     obs = vec_env.reset()

run.finish()
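The PandaReach script imports `package_to_hub` from `huggingface_sb3`, which is how the `lennartoe/PandaReachJointsDense-v3` repo linked above gets populated. A minimal upload sketch; the saved-model path, model name, and commit message are assumptions.

```python
import gymnasium as gym
import panda_gym  # noqa: F401  (importing registers the PandaReach environments)
from huggingface_sb3 import package_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

# Path of the locally saved model is an assumption.
model = A2C.load("a2c_panda_reach.zip")

package_to_hub(
    model=model,
    model_name="a2c-PandaReachJointsDense-v3",
    model_architecture="A2C",
    env_id="PandaReachJointsDense-v3",
    eval_env=DummyVecEnv([lambda: gym.make("PandaReachJointsDense-v3")]),
    repo_id="lennartoe/PandaReachJointsDense-v3",
    commit_message="Upload A2C PandaReachJointsDense-v3 agent",
)
```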
New file: reinforce_cartpole_dr_0.5.png (38.2 KiB)