Skip to content
Snippets Groups Projects
Commit 32852677 authored by MaximeCerise's avatar MaximeCerise
Browse files

Finito

parent b4e270ab
No related branches found
No related tags found
No related merge requests found
...@@ -25,4 +25,25 @@ We finally have an evaluation with 100% of success: ...@@ -25,4 +25,25 @@ We finally have an evaluation with 100% of success:
Here we set up a complete pipeline to solve Cartpole environment with A2C algorithm. Here we set up a complete pipeline to solve Cartpole environment with A2C algorithm.
Wandb has been set up to track the learning phase : [WandB tracking](https://wandb.ai/maximecerise-ecl/cartpole-a2c?nw=nwusermaximecerise) Wandb has been set up to track the learning phase : [Report here](reports/A2C_CARTPOLE_REPORT.pdf)
Preview :
<video controls src="videos/preview.mp4" title="Title"></video>
### 3. Panda Reach
Stable-Baselines3 package to train A2C model on the PandaReachJointsDense-v3 environment. 500k timesteps.
#### To run [a2c_sb3_panda_reach.py](a2c_sb3_panda_reach.py) :
<code> pip install -r requirement_reach.txt </code>
<code> python a2c_sb3_panda_reach.py </code>
- <b>Code:</b> [a2c_sb3_panda_reach.py](a2c_sb3_panda_reach.py)
- <b>Hugging face :</b> [Here](https://huggingface.co/MaximeCerise/a2c-reach)
- <b>WandB's report : </b> [a2c_reach_panda_report](<reports/A2C_reach_panda_report .pdf>)
- <b>Preview :</b>
<video controls src="videos/preview_reach.mp4" title="Title"></video>
import os

import gymnasium as gym
import panda_gym
import wandb
from huggingface_hub import login
from huggingface_sb3 import load_from_hub, package_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
def a2c_reach():
    """Train an A2C agent on PandaReachJointsDense-v3 and publish it.

    Pipeline: start a W&B run -> build a monitored vec env with a one-shot
    preview recording -> train A2C for 500k timesteps (logged via
    WandbCallback) -> evaluate -> push the model to the Hugging Face Hub ->
    reload the pushed checkpoint and roll it out as a sanity check.

    Side effects only (videos/, models/, wandb_data/, network calls to W&B
    and the HF Hub); returns nothing.
    """
    config = {
        "policy_type": "MultiInputPolicy",
        "total_timesteps": 500000,
        "env_name": "PandaReachJointsDense-v3",
    }
    model_wb_name = "a2c-reach"
    env_id = config["env_name"]
    repo_user = "MaximeCerise"
    model_push_name = "a2c-reach"
    commit_message = "v7"

    run = wandb.init(
        name=model_wb_name,
        project="a2c-reach",
        config=config,
        sync_tensorboard=True,  # mirror SB3's tensorboard scalars to W&B
        monitor_gym=True,
        save_code=True,
    )

    # Monitored single-env vec stack; render_mode must be "rgb_array" so the
    # video recorder can grab frames.
    env = gym.make(config["env_name"], render_mode="rgb_array")
    env = Monitor(env)
    env = DummyVecEnv([lambda: env])
    env = VecVideoRecorder(
        env,
        video_folder="videos",
        record_video_trigger=lambda step: step == 0,  # record only the first rollout
        video_length=3000,
        name_prefix="preview",
    )

    model = A2C(config["policy_type"], env, verbose=1,
                tensorboard_log=f"wandb_data/runs/{run.id}")
    model.learn(
        config["total_timesteps"],
        progress_bar=True,
        callback=WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"wandb_data/models/{run.id}",
            verbose=2,
        ),
    )
    run.finish()
    model.save(f"models/{model_push_name}")

    mean_reward, std_reward = evaluate_policy(
        model, env, n_eval_episodes=100, deterministic=True)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

    # Fix: never hard-code credentials — the original called login('') with an
    # empty token, which always fails. Read the token from the environment.
    login(os.environ["HF_TOKEN"])
    eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])
    package_to_hub(model=model,
                   model_name=model_push_name,
                   model_architecture="A2C",
                   env_id=env_id,
                   eval_env=eval_env,
                   repo_id=f"{repo_user}/{model_push_name}",
                   commit_message=commit_message)

    # Round-trip check: reload the checkpoint that was just pushed.
    checkpoint = load_from_hub(f"{repo_user}/{model_push_name}",
                               f"{model_push_name}.zip")
    model = A2C.load(checkpoint, print_system_info=True)

    obs = env.reset()
    for _ in range(2000):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        # Fix: VecEnv.render() takes no positional mode argument; the render
        # mode was fixed to "rgb_array" when the env was created, so the
        # original env.render("human") was invalid.
        env.render()
    env.close()
# Script entry point: run the full train/evaluate/publish pipeline.
if __name__ == "__main__":
    a2c_reach()
\ No newline at end of file
File added
File added
panda-gym==3.0.7
stable-baselines3
wandb
huggingface-hub
gymnasium
huggingface_sb3
tensorboard
moviepy
stable-baselines3[extra]
\ No newline at end of file
File moved
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment