diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py index 3666845bf8010fd247e855d9f0b60b8487ebc311..d5090b4b11bf11d82f20f0e94b7a2a883e29e44c 100644 --- a/a2c_sb3_cartpole.py +++ b/a2c_sb3_cartpole.py @@ -4,6 +4,7 @@ from stable_baselines3 import A2C from huggingface_sb3 import package_to_hub, push_to_hub from gym import envs from gymnasium.envs.registration import register +from tqdm import tqdm env_id = "CartPole-v1" @@ -17,24 +18,36 @@ model.learn(total_timesteps=10_000) vec_env = model.get_env() obs = vec_env.reset() -for i in range(1000): +for i in tqdm(range(1000)): action, _state = model.predict(obs, deterministic=True) obs, reward, done, info = vec_env.step(action) vec_env.render("human") - # VecEnv resets automatically - # if done: - # obs = vec_env.reset() - -# Package and push the model to the Hugging Face Hub -model_package_id = package_to_hub(model=model, - model_name="a2c-CartPole-v1", - model_architecture="a2c", - env_id=env_id, - eval_env=env, - repo_id="oscarchaufour/a2c-CartPole-v1", - commit_message="Initial commit of A2C CartPole model") - -# Push the model package to the Hub -push_to_hub(repo_id="oscarchaufour/a2c-CartPole-v1", - filename=model_package_id + ".zip", - commit_message="Added A2C CartPole model") \ No newline at end of file + + +####### TO BE DONE ####### + +# # Serialize the model and save it to a .zip file +# import pickle +# import zipfile + +# # Step 1: Serialize the model +# model_bytes = pickle.dumps(model) + +# # Step 2: Create a .zip file containing the serialized model +# zip_filename = env_id + ".zip" +# with zipfile.ZipFile(zip_filename, 'w') as zipf: +# zipf.writestr("model.pkl", model_bytes) + +# # Package and push the model to the Hugging Face Hub +# model_package_id = package_to_hub(model=model, +# model_name="a2c-CartPole-v1", +# model_architecture="a2c", +# env_id=env_id, +# eval_env=env, +# repo_id="oscarchaufour/a2c-CartPole-v1", +# commit_message="Initial commit of A2C CartPole model") + +# # Push the model package to the Hub +# push_to_hub(repo_id="oscarchaufour/a2c-CartPole-v1", +# filename=model_package_id + ".zip", +# commit_message="Added A2C CartPole model") \ No newline at end of file