Commit b97cfa45 authored by oscarchaufour

update

parent fb37266a
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecVideoRecorder
from gymnasium.envs.registration import register
from tqdm import tqdm
import wandb
from wandb.integration.sb3 import WandbCallback
from huggingface_sb3 import push_to_hub
import os
def train_model(config, env_id, policy, project_name):
    """
    Train a model using the A2C algorithm with Weights & Biases integration.

    Args:
        config (dict): Configuration parameters for training.
        env_id (str): Identifier of the Gym environment.
        policy (str): Type of policy to use for the model.
        project_name (str): Name of the project in Weights & Biases.

    Returns:
        A2C: Trained A2C model.
    """
    # Initialize Weights & Biases
    wandb.login()
    run = wandb.init(
        project=project_name,
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )

    # CartPole-v1 is already registered by gym, and registering it through
    # gymnasium would not touch the registry that gym.make() reads from:
    # register(id=env_id, entry_point='gym.envs.classic_control:CartPoleEnv', max_episode_steps=500)
    env = gym.make(env_id, render_mode="rgb_array")
    # env = VecVideoRecorder(
    #     env,
    #     f"videos/{run.id}",
    #     record_video_trigger=lambda x: x % 2000 == 0,
    #     video_length=200,
    # )

    model = A2C(policy, env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model.learn(total_timesteps=config["total_timesteps"])

    # Mark the run as public in W&B project settings
    run.finish()
    return model
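This revision drops the WandbCallback that the previous one passed to learn(); if gradient and checkpoint logging are wanted again, the call inside train_model could be restored along these lines (model, config, and run are the names already in scope there):

    # Optional: re-attach the W&B callback from the previous revision so that
    # gradients and model checkpoints are uploaded alongside the tensorboard logs.
    model.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            gradient_save_freq=100,              # log gradients every 100 updates
            model_save_path=f"models/{run.id}",  # checkpoint under models/<run id>
        ),
    )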
def test_model(model):
    """
    Test a trained model by running it in the environment.

    Args:
        model (A2C): Trained A2C model to be tested.
    """
    vec_env = model.get_env()
    obs = vec_env.reset()
    for _ in tqdm(range(1000)):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)
        vec_env.render("rgb_array")
def save_push_model(model, project_name):
    """
    Save the trained model and push it to the Hugging Face Model Hub.

    Args:
        model (A2C): Trained A2C model.
        project_name (str): Name of the project to save the model.
    """
    model.save(project_name + ".zip")
    # Hugging Face
    push_to_hub(
        repo_id="oscarchaufour/a2c-CartPole-v1",
        filename=project_name + ".zip",
        commit_message="Adding CartPole model trained with A2C on HuggingFace",
        token=os.environ["HF_TOKEN"],  # never hard-code a personal access token; read it from the environment
    )
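For the return trip, huggingface_sb3 also provides load_from_hub; a minimal sketch that pulls back the file uploaded above and rebuilds the agent:

from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C

# Download the zip that save_push_model() pushed and reload the policy.
checkpoint = load_from_hub(
    repo_id="oscarchaufour/a2c-CartPole-v1",
    filename="a2c-CartPole-v1.zip",
)
model = A2C.load(checkpoint)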
if __name__ == "__main__":
    env_id = "CartPole-v1"
    policy = "MlpPolicy"
    config = {
        "policy_type": policy,
        "total_timesteps": 10000,
        "env_name": env_id,
    }
    project_name = "a2c-CartPole-v1"

    trained_model = train_model(config, env_id, policy, project_name)
    test_model(trained_model)
    save_push_model(trained_model, project_name)
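The VecVideoRecorder import is only used by the commented-out block in train_model. If recording is wanted, the wrapper needs a vectorized env, so a sketch under that assumption (run is the wandb run created in train_model):

from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

# VecVideoRecorder wraps VecEnvs only, so wrap the raw env in a DummyVecEnv first.
env = DummyVecEnv([lambda: gym.make("CartPole-v1", render_mode="rgb_array")])
env = VecVideoRecorder(
    env,
    f"videos/{run.id}",                            # one video folder per W&B run
    record_video_trigger=lambda x: x % 2000 == 0,  # start a clip every 2000 steps
    video_length=200,                              # frames per clip
)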
import gym
import panda_gym  # registers the Panda environments with gym on import
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecVideoRecorder
from gymnasium.envs.registration import register
from tqdm import tqdm
import wandb
from wandb.integration.sb3 import WandbCallback
from huggingface_sb3 import push_to_hub
import os
def train_model(config, env_id, policy, project_name):
    """
    Train a model using the A2C algorithm with Weights & Biases integration.

    Args:
        config (dict): Configuration parameters for training.
        env_id (str): Identifier of the Gym environment.
        policy (str): Type of policy to use for the model.
        project_name (str): Name of the project in Weights & Biases.

    Returns:
        A2C: Trained A2C model.
    """
    # Initialize Weights & Biases
    wandb.login()
    run = wandb.init(
        project=project_name,
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
    )

    # panda_gym already registers the Panda environments when it is imported,
    # and 'gym.envs.robotics' contains no PandaReachEnv entry point:
    # register(id=env_id, entry_point='gym.envs.robotics:PandaReachEnv', max_episode_steps=500)
    env = gym.make(env_id, render_mode="rgb_array")

    model = A2C(policy, env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model.learn(total_timesteps=config["total_timesteps"])

    # Mark the run as public in W&B project settings
    run.finish()
    return model
def test_model(model):
    """
    Test a trained model by running it in the environment.

    Args:
        model (A2C): Trained A2C model to be tested.
    """
    vec_env = model.get_env()
    obs = vec_env.reset()
    for _ in tqdm(range(1000)):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)
        vec_env.render("rgb_array")
def save_push_model(model, project_name):
    """
    Save the trained model and push it to the Hugging Face Model Hub.

    Args:
        model (A2C): Trained A2C model.
        project_name (str): Name of the project to save the model.
    """
    model.save(project_name + ".zip")
    # Hugging Face
    push_to_hub(
        repo_id="oscarchaufour/a2c-PandaReachJointsDense-v2",
        filename=project_name + ".zip",
        commit_message="Adding PandaReachJointsDense model trained with A2C on HuggingFace",
        token=os.environ["HF_TOKEN"],  # never hard-code a personal access token; read it from the environment
    )
if __name__ == "__main__":
    env_id = "PandaReachJointsDense-v2"
    policy = "MlpPolicy"
    config = {
        "policy_type": policy,
        "total_timesteps": 500000,
        "env_name": env_id,
    }
    project_name = "a2c-PandaReachJointsDense-v2"

    trained_model = train_model(config, env_id, policy, project_name)
    test_model(trained_model)
    save_push_model(trained_model, project_name)
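test_model only renders the policy; for a quantitative check, Stable-Baselines3's evaluate_policy helper could be run on the trained agent, e.g.:

from stable_baselines3.common.evaluation import evaluate_policy

# Mean and spread of the episodic return over 10 deterministic evaluation episodes.
mean_reward, std_reward = evaluate_policy(
    trained_model,
    trained_model.get_env(),
    n_eval_episodes=10,
    deterministic=True,
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")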