Skip to content
Snippets Groups Projects
Commit fe50437f authored by oscarchaufour's avatar oscarchaufour
Browse files

Update a2c_sb3_cartpole.py

parent f89b6348
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub, push_to_hub
from gym import envs
from gymnasium.envs.registration import register
from tqdm import tqdm
env_id = "CartPole-v1"
......@@ -17,24 +18,36 @@ model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
for i in tqdm(range(1000)):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render("human")
# VecEnv resets automatically
# if done:
# obs = vec_env.reset()
# Package and push the model to the Hugging Face Hub
model_package_id = package_to_hub(model=model,
model_name="a2c-CartPole-v1",
model_architecture="a2c",
env_id=env_id,
eval_env=env,
repo_id="oscarchaufour/a2c-CartPole-v1",
commit_message="Initial commit of A2C CartPole model")
# Push the model package to the Hub
push_to_hub(repo_id="oscarchaufour/a2c-CartPole-v1",
filename=model_package_id + ".zip",
commit_message="Added A2C CartPole model")
\ No newline at end of file
####### TO BE DONE #######
# # Serialize the model and save it to a .zip file
# import pickle
# import zipfile
# # Step 1: Serialize the model
# model_bytes = pickle.dumps(model)
# # Step 2: Create a .zip file containing the serialized model
# zip_filename = env_id + ".zip"
# with zipfile.ZipFile(zip_filename, 'w') as zipf:
# zipf.writestr("model.pkl", model_bytes)
# # Package and push the model to the Hugging Face Hub
# model_package_id = package_to_hub(model=model,
# model_name="a2c-CartPole-v1",
# model_architecture="a2c",
# env_id=env_id,
# eval_env=env,
# repo_id="oscarchaufour/a2c-CartPole-v1",
# commit_message="Initial commit of A2C CartPole model")
# # Push the model package to the Hub
# push_to_hub(repo_id="oscarchaufour/a2c-CartPole-v1",
# filename=model_package_id + ".zip",
# commit_message="Added A2C CartPole model")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment