Skip to content
Snippets Groups Projects
Commit 789865bb authored by Majdi Karim's avatar Majdi Karim
Browse files

Add new file

parent 93ba5cfa
No related branches found
No related tags found
No related merge requests found
### LIBRARIES
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
from huggingface_sb3 import push_to_hub
import panda_gym
import os
from huggingface_hub import login
#dir_path = os.path.dirname(os.path.realpath(__file__))
#os.chdir(dir_path)
# Settings for this run; the same mapping is handed to Weights & Biases below
# and read later when building the environment and the model.
config = dict(
    policy_type="MultiInputPolicy",
    total_timesteps=250000,
    env_name="PandaReachJointsDense-v3",
)

# Open the W&B run; `run.id` is used further down to name log/model folders.
run = wandb.init(
    project="sb3-panda-reach",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)
def make_env() -> "gym.Env":
    """Create one monitored instance of the environment named in ``config``.

    Returns:
        The environment built by ``gym.make(config["env_name"])``, wrapped in
        ``Monitor``.
    """
    env = gym.make(config["env_name"])
    env = Monitor(env)  # record stats such as returns
    return env
# Wrap the single `make_env` factory in a vectorized environment.
env = DummyVecEnv([make_env])
# Optional video recording, left disabled:
# env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % 2000 == 0, video_length=200)

# Train A2C; TensorBoard logs and W&B model checkpoints are keyed by the run id.
model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        gradient_save_freq=100,
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()

# Authenticate against the Hugging Face Hub. The token is read from the
# HF_TOKEN environment variable so that no secret lives in the source file;
# when the variable is unset, `login` receives None and asks for the token
# itself.
login(token=os.environ.get("HF_TOKEN"))

model_file = "ECL-TD-RL1-a2c_panda_reach.zip"

# Save the trained model
model.save(model_file)
# Load the trained model
model = A2C.load(model_file)

# NOTE(review): the repo id says "cartpole" although this script trains on a
# Panda reach task -- confirm the target repository before publishing.
push_to_hub(
    repo_id="Karim-20/a2c_cartpole",
    filename=model_file,
    # Build the message from the configured environment id so it cannot drift
    # from what was actually trained.
    commit_message=f"Add {config['env_name']} environment, agent used to train is A2C",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment