diff --git a/README.md b/README.md
index 773bba904c24b8eeeba9185443e9dfcc5ff57225..963a281344ca98ad2e0ece61deb14a4e4fd54386 100644
--- a/README.md
+++ b/README.md
@@ -25,4 +25,25 @@
 We finally have an evaluation with 100% success:
 Here we set up a complete pipeline to solve the CartPole environment with the A2C algorithm.
-Wandb has been set up to track the learning phase : [WandB tacking](https://wandb.ai/maximecerise-ecl/cartpole-a2c?nw=nwusermaximecerise)
+WandB has been set up to track the learning phase: [Report here](reports/A2C_CARTPOLE_REPORT.pdf)
+
+Preview:
+<video controls src="videos/preview.mp4" title="CartPole A2C preview"></video>
+
+### 3. Panda Reach
+
+We use the Stable-Baselines3 package to train an A2C model on the PandaReachJointsDense-v3 environment for 500k timesteps.
+
+#### To run [a2c_sb3_panda_reach.py](a2c_sb3_panda_reach.py):
+
+<code>pip install -r requirements_reach.txt</code>
+<code>python a2c_sb3_panda_reach.py</code>
+
+- <b>Code:</b> [a2c_sb3_panda_reach.py](a2c_sb3_panda_reach.py)
+- <b>Hugging Face:</b> [Here](https://huggingface.co/MaximeCerise/a2c-reach)
+- <b>WandB report:</b> [a2c_reach_panda_report](<reports/A2C_reach_panda_report .pdf>)
+- <b>Preview:</b>
+<video controls src="videos/preview_reach.mp4" title="Panda Reach A2C preview"></video>
diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf83a600e2f3f45d22fed37181c967b8212825b
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,95 @@
+import gymnasium as gym
+import panda_gym  # registers the Panda environments with gymnasium
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import load_from_hub, package_to_hub
+from huggingface_hub import login
+from stable_baselines3 import A2C
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
+
+
+def a2c_reach():
+    config = {
+        "policy_type": "MultiInputPolicy",  # panda-gym observations are dicts
+        "total_timesteps": 500000,
+        "env_name": "PandaReachJointsDense-v3",
+    }
+
+    model_wb_name = "a2c-reach"
+
+    env_id = config["env_name"]
+    repo_user = "MaximeCerise"
+    model_push_name = "a2c-reach"
+    commit_message = "v7"
+
+    run = wandb.init(
+        name=model_wb_name,
+        project="a2c-reach",
+        config=config,
+        sync_tensorboard=True,
+        monitor_gym=True,
+        save_code=True,
+    )
+
+    env = gym.make(config["env_name"], render_mode="rgb_array")
+    env = Monitor(env)
+    env = DummyVecEnv([lambda: env])
+    # Record a single preview clip at the very first step of training.
+    env = VecVideoRecorder(
+        env,
+        video_folder="videos",
+        record_video_trigger=lambda x: x == 0,
+        video_length=3000,
+        name_prefix="preview",
+    )
+
+    model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"wandb_data/runs/{run.id}")
+    model.learn(
+        config["total_timesteps"],
+        progress_bar=True,
+        callback=WandbCallback(
+            gradient_save_freq=100,
+            model_save_path=f"wandb_data/models/{run.id}",
+            verbose=2,
+        ),
+    )
+    run.finish()
+
+    model.save(f"models/{model_push_name}")
+
+    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True)
+    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
+
+    login()  # expects a Hugging Face token, e.g. from the HF_TOKEN environment variable
+
+    eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])
+
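+    # package_to_hub (from huggingface_sb3) evaluates the model on eval_env,
+    # records a replay video, generates a model card, and pushes everything
+    # to the Hugging Face Hub repository given by repo_id.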
repo_id=f"{repo_user}/{model_push_name}", + commit_message=commit_message) + + checkpoint = load_from_hub(f"{repo_user}/{model_push_name}", f"{model_push_name}.zip") + + model = A2C.load(checkpoint, print_system_info=True) + + obs = env.reset() + for i in range(2000): + action, _state = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render("human") + +if __name__ == "__main__": + a2c_reach() \ No newline at end of file diff --git a/reports/A2C_CARTPOLE_REPORT.pdf b/reports/A2C_CARTPOLE_REPORT.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d16c921b6bc470350c5e16b9ae8c9ea94c6d5939 Binary files /dev/null and b/reports/A2C_CARTPOLE_REPORT.pdf differ diff --git a/reports/A2C_reach_panda_report .pdf b/reports/A2C_reach_panda_report .pdf new file mode 100644 index 0000000000000000000000000000000000000000..7edcaed59e5546fef020f2bc611684ed188299e5 Binary files /dev/null and b/reports/A2C_reach_panda_report .pdf differ diff --git a/requirements_reach.txt b/requirements_reach.txt new file mode 100644 index 0000000000000000000000000000000000000000..a7cc947541215649f8d504c3afc4299cc3baab2f --- /dev/null +++ b/requirements_reach.txt @@ -0,0 +1,9 @@ +panda-gym==3.0.7 +stable-baselines3 +wandb +huggingface-hub +gymnasium +huggingface_sb3 +tensorboard +moviepy +stable-baselines3[extra] \ No newline at end of file diff --git a/videos/rl-video-step-0-to-step-1000.mp4 b/videos/preview.mp4 similarity index 100% rename from videos/rl-video-step-0-to-step-1000.mp4 rename to videos/preview.mp4 diff --git a/videos/preview_reach.mp4 b/videos/preview_reach.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b3937322a7d779f52cca48e7c73e26f624093593 Binary files /dev/null and b/videos/preview_reach.mp4 differ