Skip to content
Snippets Groups Projects
Commit 089be291 authored by td's avatar td
Browse files

- add of a2c_sb3_panda_reach.ipynb (notebook version)

- removal of hf and wandb logs from code (for security)
- final version of the reamdme (last part added)
parent 54877b97
Branches
No related tags found
No related merge requests found
......@@ -25,9 +25,20 @@ see [a2c_sb3_cartpole.py](a2c_sb3_cartpole.py)
### Hugging Face Hub
[Link to the trained model](https://huggingface.co/Thomstr/A2C_CartPole/tree/main)
[Link to the trained model (cartpole)](https://huggingface.co/Thomstr/A2C_CartPole/tree/main)
### Weights & Biases
[Link to the wandb run](https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/cartpole/runs/vh4anh20/workspace?nw=nwuserthomasdgr)
[Link to the wandb run (cartpole)](https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/cartpole/runs/vh4anh20/workspace?nw=nwuserthomasdgr)
### Full workflow with panda-gym
see [a2c_sb3_panda_reach.py](a2c_sb3_panda_reach.py)
As I couldn't make it work on my PC (difficulties to install panda-gym), I've used Google Colab.
see my notebook [here (online)](https://colab.research.google.com/drive/1l03F398QLHHVVqJ-GvRgxA4d-cCocF4K?usp=sharing)
or directly [a2c_sb3_panda_reach.ipynb](a2c_sb3_panda_reach.ipynb)
[Link to the trained model (panda reach)](https://huggingface.co/Thomstr/A2C_PandaReach/tree/main)
[Link to the wandb run (panda reach)](https://wandb.ai/thomasdgr-ecole-centrale-de-lyon/pandareach/runs/y39cy9ws?nw=nwuserthomasdgr)
......@@ -21,7 +21,7 @@ if __name__ == "__main__":
"total_timesteps": 25000,
"env_name": "CartPole-v1",
}
wandb.login(key='4ac81e81b051a56ebfc528b579021cfc9ed1e5dc')
wandb.login(key='xxxxxxx')
run = wandb.init(
project="cartpole",
config=config,
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
import gym
import gymnasium as gym
import panda_gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
from huggingface_hub import login
from huggingface_sb3 import push_to_hub
if __name__ == "__main__":
config = {
"policy_type": "MultiInputPolicy",
"total_timesteps": 500000,
"env_name": "PandaReachJointsDense-v3",
}
wandb.login(key='xxx')
run = wandb.init(
project="pandareach",
config=config,
......@@ -21,13 +26,7 @@ run = wandb.init(
save_code=True,
)
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # record stats such as returns
return env
env = DummyVecEnv([make_env])
env = gym.make("PandaReachJointsDense-v3")
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
......@@ -37,7 +36,9 @@ model.learn(
run.finish()
login(token="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")
model.save("a2c_pandareach")
login(token="xxx")
push_to_hub(
repo_id="Thomstr/A2C_PandaReach",
filename="a2c_pandareach.zip",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment