diff --git a/README.md b/README.md
index b47be3e415739e2939784f0a075278b1b12f1ca4..6a163fc44bebfe2a4ac585664b334f49755c4128 100644
--- a/README.md
+++ b/README.md
@@ -12,4 +12,11 @@ Although, with a bit of luck we end up with a model that reaches the max steps p
 ### Evaluation
-During evaluation, we get a 100% success rate.
\ No newline at end of file
+During evaluation, the agent achieves a 100% success rate over 100 trials.
+
+## Familiarization with a complete RL pipeline: Application to training a robotic arm
+We initialize the PandaReachJointsDense-v3 environment, train an A2C agent on it for 500,000 timesteps (see a2c_sb3_panda_reach.py), track the run with Weights & Biases, and push the trained model to the Hugging Face Hub.
+
+Trained model on the Hugging Face Hub: https://huggingface.co/Thomstr/A2C_CartPole/tree/main
+
+Training run on Weights & Biases: https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/cartpole/runs/vh4anh20/workspace?nw=nwuserthomasdgr
diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a370e3a67325fd902898f59327aed6f8110c66e
--- /dev/null
+++ b/a2c_sb3_cartpole.py
@@ -0,0 +1,51 @@
+import os
+
+from stable_baselines3 import A2C
+from stable_baselines3.common.env_util import make_vec_env
+from huggingface_hub import login
+from huggingface_sb3 import push_to_hub
+import wandb
+
+
+def play(model, vec_env):
+    """Roll out the trained policy (requires envs created with render_mode="human")."""
+    obs = vec_env.reset()
+    while True:
+        action, _states = model.predict(obs)
+        obs, rewards, dones, info = vec_env.step(action)
+        vec_env.render()
+
+
+if __name__ == "__main__":
+    # Start a new wandb run to track this script.
+    config = {
+        "policy_type": "MlpPolicy",
+        "total_timesteps": 25000,
+        "env_name": "CartPole-v1",
+    }
+    # The API key is read from the WANDB_API_KEY environment variable;
+    # never commit credentials to the repository.
+    wandb.login()
+    run = wandb.init(
+        project="cartpole",
+        config=config,
+        sync_tensorboard=True,
+        monitor_gym=True,
+        save_code=True,
+    )
+
+    # Parallel environments
+    vec_env = make_vec_env(config["env_name"], n_envs=4)
+
+    model = A2C(config["policy_type"], vec_env, verbose=1)
+    model.learn(total_timesteps=config["total_timesteps"])
+    model.save("a2c_cartpole")
+
+    # Set PUSH_TO_HUB=1 to upload the saved model to the Hugging Face Hub.
+    if os.environ.get("PUSH_TO_HUB"):
+        login(token=os.environ["HF_TOKEN"])  # token from the environment, not the source
+        push_to_hub(
+            repo_id="Thomstr/A2C_CartPole",
+            filename="a2c_cartpole.zip",
+            commit_message="Added A2C model for CartPole with Stable Baselines3",
+        )
diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec45523687c7c76469ed2df686b056272f80b405
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,52 @@
+import os
+
+import gymnasium as gym
+import panda_gym  # registers the Panda*-v3 environments
+from huggingface_hub import login
+from huggingface_sb3 import push_to_hub
+from stable_baselines3 import A2C
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv
+import wandb
+from wandb.integration.sb3 import WandbCallback
+
+
+config = {
+    "policy_type": "MultiInputPolicy",
+    "total_timesteps": 500000,
+    "env_name": "PandaReachJointsDense-v3",
+}
+
+run = wandb.init(
+    project="pandareach",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+
+
+def make_env():
+    env = gym.make(config["env_name"])
+    env = Monitor(env)  # record stats such as returns
+    return env
+
+
+env = DummyVecEnv([make_env])
+model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(),
+)
+model.save("a2c_pandareach")  # produces the zip file pushed to the Hub below
+
+run.finish()
+
+# The token is read from the HF_TOKEN environment variable; never commit
+# credentials to the repository.
+login(token=os.environ["HF_TOKEN"])
+push_to_hub(
+    repo_id="Thomstr/A2C_PandaReach",
+    filename="a2c_pandareach.zip",
+    commit_message="Added A2C model for PandaReach with Stable Baselines3",
+)
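
For context on the README's evaluation claim, here is a minimal sketch of how the 100-trial evaluation could be reproduced from the model pushed above. It assumes that "success" means reaching CartPole-v1's 500-step episode cap (the README does not define the criterion), and it uses `load_from_hub` from `huggingface_sb3` to fetch the checkpoint pushed by a2c_sb3_cartpole.py:

```python
# Evaluation sketch (not part of the diff above). Assumes "success" for
# CartPole-v1 means surviving until the 500-step time limit.
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C

# Download the checkpoint pushed by a2c_sb3_cartpole.py and load it.
checkpoint = load_from_hub(repo_id="Thomstr/A2C_CartPole", filename="a2c_cartpole.zip")
model = A2C.load(checkpoint)

env = gym.make("CartPole-v1")
successes = 0
for _ in range(100):
    obs, _ = env.reset()
    done, steps = False, 0
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        steps += 1
    successes += steps >= 500  # only truncation at the time limit counts as success
print(f"Success rate: {successes}/100")
```

`deterministic=True` disables exploration noise at evaluation time, which is the usual choice when reporting success rates for a trained policy.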