Skip to content
Snippets Groups Projects
Commit da1ca4a3 authored by td's avatar td
Browse files

avancement

parent f0926141
Branches
No related tags found
No related merge requests found
......@@ -12,4 +12,11 @@ Although, with a bit of luck we end up with a model that reaches the max steps p
### Evaluation
During evaluation, we get a 100% success rate.
\ No newline at end of file
During evaluation, the trained agent reaches a 100% success rate over 100 trials.
## Familiarization with a complete RL pipeline: Application to training a robotic arm
We initialize the
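- Trained A2C CartPole model on the Hugging Face Hub: https://huggingface.co/Thomstr/A2C_CartPole/tree/main
- Weights & Biases run for the CartPole training: https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/cartpole/runs/vh4anh20/workspace?nw=nwuserthomasdgr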
\ No newline at end of file
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from huggingface_hub import login
from huggingface_sb3 import push_to_hub
import wandb
def play(model, env=None, max_steps=None):
    """Roll out ``model`` in a vectorized environment, rendering every step.

    Args:
        model: Object exposing ``predict(obs)`` that returns an
            ``(action, state)`` pair.
        env: Environment exposing ``reset()``, ``step(action)`` and
            ``render(mode)``. Defaults to the module-level ``vec_env``
            created by the script entry point.
        max_steps: Number of steps to play. ``None`` (the default) keeps
            the original behavior of looping until interrupted.

    Returns:
        The number of steps played (only reachable when ``max_steps`` is set).
    """
    if env is None:
        # Backward-compatible fallback: the environment built under
        # ``if __name__ == "__main__":``.
        env = vec_env

    obs = env.reset()
    steps = 0
    while max_steps is None or steps < max_steps:
        action, _ = model.predict(obs)
        # Only the next observation is needed to keep the rollout going.
        obs, _rewards, _dones, _info = env.step(action)
        env.render("human")
        steps += 1
    return steps
if __name__ == "__main__":
    # Local import: only the script entry point needs environment access.
    import os

    # Single source of truth for the run's hyperparameters; the same mapping
    # is logged to Weights & Biases.
    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 25000,
        "env_name": "CartPole-v1",
    }

    # Credentials are read from the environment instead of being committed.
    # NOTE(review): the key and token that were previously hard-coded here
    # are exposed in version control and should be revoked.
    # When WANDB_API_KEY is unset, key=None is passed and wandb is expected
    # to use its own credential lookup - confirm for the installed version.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    # start a new wandb run to track this script
    run = wandb.init(
        project="cartpole",
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    )

    # Parallel environments
    vec_env = make_vec_env(config["env_name"], n_envs=4)

    # tensorboard_log gives sync_tensorboard=True something to mirror.
    model = A2C(
        config["policy_type"],
        vec_env,
        verbose=1,
        tensorboard_log=f"runs/{run.id}",
    )
    model.learn(total_timesteps=config["total_timesteps"])
    model.save("a2c_cartpole")
    run.finish()

    # Uploading stays off by default (it was previously disabled with
    # ``if False``); set PUSH_TO_HUB=1 to publish the saved model.
    if os.environ.get("PUSH_TO_HUB") == "1":
        login(token=os.environ.get("HF_TOKEN"))
        push_to_hub(
            repo_id="Thomstr/A2C_CartPole",
            filename="a2c_cartpole.zip",
            commit_message="Added A2C model for CartPole with Stable Baselines3",
        )
import gym
import panda_gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
# Hyperparameters for the PandaReach run; the same mapping is handed to
# Weights & Biases so it is recorded with the run.
config = dict(
    policy_type="MultiInputPolicy",
    total_timesteps=500000,
    env_name="PandaReachJointsDense-v3",
)

# Options for the tracking run, gathered first and then unpacked into init.
_wandb_options = dict(
    project="pandareach",
    config=config,
    sync_tensorboard=True,
    monitor_gym=True,
    save_code=True,
)
run = wandb.init(**_wandb_options)
def make_env():
    """Build the environment named in ``config`` wrapped in a ``Monitor``.

    The ``Monitor`` wrapper records stats such as returns.
    """
    # NOTE(review): the "-v3" panda-gym ids may require gymnasium rather
    # than the legacy ``gym`` package imported by this script - confirm.
    return Monitor(gym.make(config["env_name"]))
# Imported here because this script's import block does not provide them;
# they are only needed for the final upload step.
import os

from huggingface_hub import login
from huggingface_sb3 import push_to_hub

# Train on the monitored, vectorized environment. The previous code replaced
# it with a second, raw ``gym.make`` environment right after creating it.
env = DummyVecEnv([make_env])
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(),
)
run.finish()

# Write the trained policy to disk so the file uploaded below exists
# (Stable Baselines3 adds the ".zip" extension when saving).
model.save("a2c_pandareach")

# The token is read from the environment instead of being committed.
# NOTE(review): the token that was previously hard-coded here is exposed in
# version control and should be revoked.
login(token=os.environ.get("HF_TOKEN"))
push_to_hub(
    repo_id="Thomstr/A2C_PandaReach",
    filename="a2c_pandareach.zip",
    commit_message="Added A2C model for PandaReach with Stable Baselines3",
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment