import os

import gymnasium as gym
import wandb
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from huggingface_hub import HfApi
from wandb.integration.sb3 import WandbCallback

# Set up the CartPole environment
env = gym.make("CartPole-v1", render_mode="rgb_array")
# Create the A2C model
model = A2C("MlpPolicy", env, verbose=1)
# Evaluate and print the mean reward before training
reward_before_moy, _ = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward before training: {reward_before_moy:.2f}")
# Train the model for 10,000 timesteps
model.learn(total_timesteps=10_000)
# Evaluate and print the mean reward after training
reward_after_moy, _ = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward after training: {reward_after_moy:.2f}")

# Save and upload the model
# Save the trained model locally (Stable-Baselines3 appends ".zip")
model_save_path = "model"
model.save(model_save_path)
model_path = "model.zip"
# Read the Hugging Face token from the environment instead of hardcoding it
hf_token = os.environ["HF_TOKEN"]
hf_api = HfApi()
# Create the repository on the Hugging Face Hub
repo_id = "hchauvin78/BE-RL"
hf_api.create_repo(token=hf_token, repo_id=repo_id, exist_ok=True)
# Upload the model file to the repository
hf_api.upload_file(token=hf_token, repo_id=repo_id, path_or_fileobj=model_path, path_in_repo="model.zip")
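
# Optional sanity check (a sketch added here, not in the original script): download
# the uploaded file back from the Hub and reload it with A2C.load to confirm the
# artifact round-trips. Assumes the upload above succeeded and that path_in_repo
# was "model.zip".
from huggingface_hub import hf_hub_download

downloaded_path = hf_hub_download(repo_id=repo_id, filename="model.zip", token=hf_token)
restored_model = A2C.load(downloaded_path, env=env)
mean_restored, _ = evaluate_policy(restored_model, env, n_eval_episodes=10)
print(f"Mean reward of the restored model: {mean_restored:.2f}")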


# Training with WandB
# Initializing WandB
wandb.init(project="cartpole-training", entity="hchauvin78", anonymous="allow")

# Configure the WandB run
config = wandb.config
config.learning_rate = 0.001
config.gamma = 0.99
config.n_steps = 500

# Monitor episode rewards with WandB by rolling out the policy episode by episode
# (a WandbCallback-based training sketch follows the loop below)
model = A2C("MlpPolicy", env, learning_rate=config.learning_rate, gamma=config.gamma,
            n_steps=config.n_steps, verbose=1, tensorboard_log="logs/")
episode_rewards = []

for i in range(25_000):
    obs, _ = env.reset()
    reward_tot = 0
    terminated = truncated = False

    # Roll out one episode with the current policy
    while not (terminated or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        reward_tot += reward

    episode_rewards.append(reward_tot)
    wandb.log({"Episode Reward": reward_tot, "Episode": i})
    # Log the mean reward over the last 10 episodes
    if i % 10 == 0:
        mean_reward = sum(episode_rewards[-10:]) / len(episode_rewards[-10:])
        wandb.log({"Mean Reward": mean_reward})


# Log the final metrics to WandB
wandb.log({"Final Mean Reward": sum(episode_rewards) / len(episode_rewards)})
# Finish the WandB run
wandb.finish()