Skip to content
Snippets Groups Projects
Commit da1ca4a3 authored by td's avatar td
Browse files

avancement

parent f0926141
No related branches found
No related tags found
No related merge requests found
...@@ -12,4 +12,11 @@ Although, with a bit of luck we end up with a model that reaches the max steps p ...@@ -12,4 +12,11 @@ Although, with a bit of luck we end up with a model that reaches the max steps p
### Evaluation
During evaluation, we get a 100% success rate for 100 trials.
\ No newline at end of file
## Familiarization with a complete RL pipeline: Application to training a robotic arm
We initialize the environment and the A2C model as shown in the script below.
https://huggingface.co/Thomstr/A2C_CartPole/tree/main
https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/cartpole/runs/vh4anh20/workspace?nw=nwuserthomasdgr
\ No newline at end of file
import os

import wandb
from huggingface_hub import login
from huggingface_sb3 import push_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
def play(model):
    """Run the trained model in its environment forever, rendering each step.

    Note: this loop never terminates; interrupt manually to stop.
    """
    # Use the env the model was trained on rather than the module-level
    # `vec_env` global, which only exists when the file runs as a script.
    env = model.get_env()
    obs = env.reset()
    while True:
        # predict() samples an action from the current policy.
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render("human")
if __name__ == "__main__":
    # Hyperparameters and metadata logged with the W&B run.
    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 25000,
        "env_name": "CartPole-v1",
    }
    # SECURITY: never hard-code API keys in source control. Read the key
    # from the environment; wandb.login(key=None) falls back to any
    # previously stored credentials.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    run = wandb.init(
        project="cartpole",
        config=config,
        sync_tensorboard=True,  # mirror SB3 tensorboard metrics to W&B
        monitor_gym=True,
        save_code=True,
    )
    # Parallel environments: 4 copies of CartPole stepped in lockstep.
    vec_env = make_vec_env(config["env_name"], n_envs=4)
    model = A2C(config["policy_type"], vec_env, verbose=1)
    model.learn(total_timesteps=config["total_timesteps"])
    model.save("a2c_cartpole")
    # Optional upload to the Hugging Face Hub, gated behind an env var
    # (replaces the previous dead `if False:` block; default is still no push).
    if os.environ.get("PUSH_TO_HUB"):
        login(token=os.environ["HF_TOKEN"])
        push_to_hub(
            repo_id="Thomstr/A2C_CartPole",
            filename="a2c_cartpole.zip",
            commit_message="Added A2C model for CartPole with Stable Baselines3",
        )
import gym
import panda_gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
# Hyperparameters and run metadata for the PandaReach experiment.
config = dict(
    policy_type="MultiInputPolicy",
    total_timesteps=500000,
    env_name="PandaReachJointsDense-v3",
)
# Start a W&B run that tracks this training session.
run = wandb.init(
    project="pandareach",
    config=config,
    sync_tensorboard=True,  # mirror SB3 tensorboard metrics to W&B
    monitor_gym=True,
    save_code=True,
)
def make_env():
    """Build one monitored instance of the configured environment."""
    # Monitor records episode returns and lengths for logging.
    return Monitor(gym.make(config["env_name"]))
# SB3 expects a vectorized env; wrap the monitored env factory.
env = DummyVecEnv([make_env])
# BUG FIX: the original then overwrote `env` with a bare gym.make(...),
# discarding the Monitor/DummyVecEnv wrappers (losing episode statistics).
# That overwrite is removed.
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    # Stream training metrics to the W&B run.
    callback=WandbCallback(),
)
run.finish()
# BUG FIX: the pushed zip was never written; save the model first so
# `a2c_pandareach.zip` exists.
model.save("a2c_pandareach")
# SECURITY: read the Hugging Face token from the environment instead of
# committing it to source control.
login(token=os.environ["HF_TOKEN"])
push_to_hub(
    repo_id="Thomstr/A2C_PandaReach",
    filename="a2c_pandareach.zip",
    commit_message="Added A2C model for PandaReach with Stable Baselines3",
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment