import wandb
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from huggingface_sb3 import package_to_hub
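
# Requires: stable-baselines3, gymnasium, wandb, huggingface-sb3,
# plus moviepy for VecVideoRecorder.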

def a2c_sb3():
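    """Train an A2C agent on CartPole-v1, evaluate it, record a rollout video, and push the model to the Hugging Face Hub."""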
    # Build the training environment: Monitor records episode returns/lengths,
    # DummyVecEnv provides the vectorized interface SB3 expects.
    def make_env():
        return Monitor(gym.make("CartPole-v1", render_mode="rgb_array"))

    env = DummyVecEnv([make_env])

    # Start the Weights & Biases run; sync_tensorboard mirrors SB3's
    # TensorBoard scalars into the wandb dashboard.
    wandb.init(
        entity="maximecerise-ecl",
        project="cartpole-a2c",
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True
    )


    # Train an A2C agent with the default MLP policy.
    model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_tensorboard/")
    model.learn(total_timesteps=300000)
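    # Optional alternative: SB3 metrics can also be streamed through wandb's
    # SB3 integration instead of TensorBoard syncing (assumes
    # wandb.integration.sb3 is available in the installed wandb version):
    #   from wandb.integration.sb3 import WandbCallback
    #   model.learn(total_timesteps=300000, callback=WandbCallback())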


    model.save("a2c_cartpole")
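    # SB3 appends ".zip", so this writes a2c_cartpole.zip next to the script.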
    env.close()


    # Evaluate on a fresh environment so training statistics are not mixed in.
    eval_env = DummyVecEnv([make_env])

    success_count = 0
    num_episodes = 100
    scores = []

    # Roll out num_episodes evaluation episodes with deterministic actions.
    for episode in range(num_episodes):
        obs = eval_env.reset()
        done = False
        episode_reward = 0

        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = eval_env.step(action)
            done = done[0]  # unwrap the single-env boolean from the VecEnv batch
            episode_reward += reward[0]  # reward is a length-1 array; keep a scalar

        scores.append(episode_reward)
        if episode_reward >= 200:  # success threshold (CartPole-v1 caps at 500)
            success_count += 1


        wandb.log({
            "episode": episode,
            "episode_reward": episode_reward,
            "success_rate (%)": success_count / (episode + 1) * 100
        })


    success_rate = success_count / num_episodes * 100
    avg_score = np.mean(scores)


    wandb.log({
        "final_success_rate (%)": success_rate,
        "final_average_score": avg_score
    })

    print(f"Taux de succès du modèle A2C sur {num_episodes} épisodes : {success_rate:.2f}%")
    print(f"Score moyen : {avg_score:.2f}")


    # Record one rollout to ./videos/: the trigger fires only at step 0,
    # and the recording is capped at 1000 frames.
    video_folder = "./videos/"
    eval_env = VecVideoRecorder(eval_env, video_folder, record_video_trigger=lambda x: x == 0, video_length=1000)

    obs = eval_env.reset()
    for _ in range(1000):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, _, _ = eval_env.step(action)

    eval_env.close()  # closing the recorder flushes the video file to disk


    # package_to_hub runs its own evaluation and replay recording, so it needs
    # an open environment; the recorder above was just closed, so build a fresh one.
    hub_eval_env = DummyVecEnv([make_env])
    package_to_hub(
        model=model,
        model_name="a2c_cartpole",
        model_architecture="A2C",
        env_id="CartPole-v1",
        eval_env=hub_eval_env,
        repo_id="MaximeCerise/a2c_cartpole",
        commit_message="add a2c with evaluation"
    )


    wandb.finish()
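
# The saved agent can be reloaded later with:
#   model = A2C.load("a2c_cartpole")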

if __name__ == "__main__":
    a2c_sb3()