Commit 4a98fcec authored by oscarchaufour

update

parent fe50437f
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# experiment artifacts (runs/, wandb/ and models/ are written by the training script below)
runs
wandb
models
# macOS files
.DS_Store
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
File added
@@ -6,3 +6,8 @@ Link to the model on the hub :

# REINFORCE
Plot showing the total reward across episodes: ![Alt text](images/reinforce_rewards.png)

# A2C trained model
Link to the trained model (available on Hugging Face): https://huggingface.co/oscarchaufour/a2c-CartPole-v1
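For reference, a checkpoint published this way can typically be pulled back with huggingface_sb3; a minimal sketch, where the filename of the zip inside the repo is an assumption:

from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C

# Download the checkpoint from the Hub; "a2c-CartPole-v1.zip" is an assumed
# filename, not confirmed by the commit.
checkpoint = load_from_hub(
    repo_id="oscarchaufour/a2c-CartPole-v1",
    filename="a2c-CartPole-v1.zip",
)
model = A2C.load(checkpoint)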
import gymnasium as gym
import cv2
from stable_baselines3 import A2C
from huggingface_sb3 import package_to_hub, push_to_hub
from gym import envs
from gymnasium.envs.registration import register
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.vec_env import VecVideoRecorder
import dill
import zipfile

# Initialize Weights & Biases
total_timesteps = 10000
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": total_timesteps,
    "env_name": "CartPole-v1",
}
wandb.login()
run = wandb.init(
    project="a2c-cartpole-v1",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)

env_id = "CartPole-v1"
# Register the environment
register(id=env_id, entry_point='gym.envs.classic_control:CartPoleEnv', max_episode_steps=500)
env = gym.make(env_id)

# env = VecVideoRecorder(
#     env,
#     f"videos/{run.id}",
#     record_video_trigger=lambda x: x % 2000 == 0,
#     video_length=200,
# )

model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(total_timesteps=total_timesteps, callback=WandbCallback(
    gradient_save_freq=100,
    model_save_path=f"models/{run.id}"))

# Mark the run as public in W&B project settings
run.finish()

vec_env = model.get_env()
obs = vec_env.reset()
for i in tqdm(range(1000)):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()


def save_model(model, env_id):
    # Step 1: Serialize the model
    model_bytes = dill.dumps(model)
    # Step 2: Create a .zip file containing the serialized model
    zip_filename = env_id + ".zip"
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        zipf.writestr("model.pkl", model_bytes)


# # Package and push the model to the Hugging Face Hub
# model_package_id = package_to_hub(model=model,
#                                   model_name="a2c-CartPole-v1",
#                                   model_architecture="a2c",
#                                   env_id=env_id,
#                                   eval_env=env,
#                                   repo_id="oscarchaufour/a2c-CartPole-v1",
#                                   commit_message="Initial commit of A2C CartPole model")

# # Push the model package to the Hub
# push_to_hub(repo_id="oscarchaufour/a2c-CartPole-v1",
#             filename=model_package_id + ".zip",
#             commit_message="Added A2C CartPole model")
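The commit only writes the archive; a counterpart loader could look like the following minimal sketch, assuming the layout produced by save_model above (load_model itself is a hypothetical helper, not part of the commit):

def load_model(env_id):
    # Hypothetical counterpart to save_model: read the pickled model back
    # out of the .zip archive and deserialize it with dill.
    zip_filename = env_id + ".zip"
    with zipfile.ZipFile(zip_filename, "r") as zipf:
        model_bytes = zipf.read("model.pkl")
    return dill.loads(model_bytes)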
images/reinforce_rewards.png updated (35.7 KiB → 51.7 KiB)
File deleted
@@ -95,7 +95,7 @@ def plot_rewards(episodes_rewards):

if __name__ == "__main__":
    # Create the environment
    env = gym.make("CartPole-v1", render_mode="human")
    # Set up the agent
    policy = Policy(
    ...
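The render_mode change reflects the gymnasium API, where the render mode is fixed when the environment is created rather than passed to render(); a minimal sketch of that pattern:

import gymnasium as gym

# In gymnasium the render mode is declared at construction time; with
# render_mode="human" each env.step() call draws the frame itself, so no
# env.render(mode=...) call is needed.
env = gym.make("CartPole-v1", render_mode="human")
obs, info = env.reset()
for _ in range(200):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()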