import gymnasium as gym

from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

import wandb
import panda_gym
from wandb.integration.sb3 import WandbCallback
from huggingface_hub import login
from huggingface_sb3 import push_to_hub

if __name__ == "__main__":
    # Train an A2C agent on a panda_gym environment, log the run to Weights &
    # Biases, save the trained model locally, then upload it to the
    # Hugging Face Hub.

    # Log in to the Hugging Face Hub (needed later for push_to_hub).
    login()

    # Run configuration; also passed to wandb so it is recorded with the run.
    config = {
        "policy_type": "MultiInputPolicy",
        "total_timesteps": 500_000,
        "env_name": "PandaReachJointsDense-v3",
    }
    # Single source of truth for the checkpoint file: it is written by
    # model.save() and then read back by push_to_hub().
    model_filename = "PandaReachJointsDense_1.zip"

    # Initialize a new wandb run. sync_tensorboard mirrors the SB3
    # tensorboard logs (written under runs/<run.id>) into wandb.
    run = wandb.init(
        project="a2c_sb3_panda_reach",
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
    )

    def make_env():
        """Create one environment instance wrapped in Monitor.

        Monitor records per-episode stats such as returns and lengths.
        """
        env = gym.make(config["env_name"])
        env = Monitor(env)  # Record stats such as returns
        return env

    env = DummyVecEnv([make_env])
    try:
        model = A2C(
            config["policy_type"],
            env,
            verbose=1,
            tensorboard_log=f"runs/{run.id}",
        )

        # W&B callback: logs gradients every 100 steps and saves the model
        # under models/<run.id>.
        wandb_callback = WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"models/{run.id}",
            verbose=2,
        )

        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=wandb_callback,
        )
        model.save(model_filename)
    finally:
        # Always release the environment and close the wandb run, even if
        # training or saving raised. Otherwise the env leaks and the run is
        # left unfinished.
        env.close()
        run.finish()

    # Upload on HF. Reached only if training and saving succeeded.
    push_to_hub(
        repo_id="CorentinGst/PandaReachJointsDense_1",
        filename=model_filename,
        commit_message=f"Add my 1st model trained on {config['env_name']} env",
    )