import wandb
from huggingface_hub import login
from huggingface_sb3 import package_to_hub, push_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from wandb.integration.sb3 import WandbCallback


# Training settings, kept in one dict so the same values drive the run below
# and are recorded in the W&B run config.
config = {
    "policy_type": "MlpPolicy",
    "env_id": "CartPole-v1",
    "n_envs": 4,
    "total_timesteps": 250,
}

# Parallel environments
vec_env = make_vec_env(config["env_id"], n_envs=config["n_envs"])

model = A2C(config["policy_type"], vec_env, verbose=1)

# WandbCallback logs into the currently active W&B run, so a run must be
# started before the callback is constructed. `run.id` also names the
# directory the callback saves the trained model into. Using the run as a
# context manager finishes it when training ends, including on an exception.
with wandb.init(project="a2c_sb3_cartpole", config=config) as run:
    model.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )

# model.save("a2c_cartpole")

# model = A2C.load("a2c_cartpole")

# obs = vec_env.reset()
# while True:
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = vec_env.step(action)
#     vec_env.render("human")

# login()

# package_to_hub(model=model, 
#                model_name="a2c_sb3_cartpole",
#                model_architecture="a2c_sb3_cartpole",
#                env_id="CartPole-v1",
#                eval_env=vec_env,
#                repo_id="SimRams/a2c_sb3_cartpole",
#                commit_message="Test commit")

# push_to_hub(
#     repo_id="SimRams/a2c_sb3_cartpole",
#     filename="a2c_cartpole.zip",
#     commit_message="Test commit",
# )