#
# NOTE: For an unknown reason, I could not download and use panda_gym,
# so the code is provided here as-is and has not been tested.
#

import os
import gymnasium as gym
import panda_gym
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from stable_baselines3.common.evaluation import evaluate_policy
from huggingface_hub import login
from huggingface_sb3 import package_to_hub

# Experiment configuration, shared by the W&B run, the environment factory,
# the model constructor and the training call.
config = {
    # panda-gym environments expose a Dict observation space (observation /
    # achieved_goal / desired_goal). Stable-Baselines3 rejects "MlpPolicy" for
    # Dict observations and requires "MultiInputPolicy" instead.
    "policy_type": "MultiInputPolicy",
    "total_timesteps": 500_000,  # training budget passed to model.learn()
    "env_id": "PandaReachJointsDense-v3",  # Gymnasium id passed to gym.make()
    "n_envs": 4,  # number of environment copies placed in the DummyVecEnv
}

# Open the Weights & Biases run; its id names the video, TensorBoard and
# model directories used further down in this script.
run = wandb.init(
    config=config,  # hyperparameter dict attached to the run
    entity="sim-ramos01-centrale-lyon",
    project="sb3-panda-reach",
    save_code=True,
    # NOTE(review): per the W&B Stable-Baselines3 guide, these two flags mirror
    # the TensorBoard metrics and the recorded videos to the run -- confirm.
    monitor_gym=True,
    sync_tensorboard=True,
)

# Environment factory; frames are rendered as RGB arrays so that a video
# recorder wrapped around the vectorized environment can capture them.
def make_env():
    """Build one monitored instance of the configured environment.

    Returns:
        The environment named by ``config["env_id"]``, created with
        ``render_mode="rgb_array"`` and wrapped in ``Monitor``.
    """
    return Monitor(gym.make(config["env_id"], render_mode="rgb_array"))

# Training environment: `n_envs` copies stepped in a DummyVecEnv, wrapped in a
# recorder that starts a 200-step clip every 2000 recorder steps.
env = VecVideoRecorder(
    DummyVecEnv([make_env] * config["n_envs"]),
    f"videos/{run.id}",
    record_video_trigger=lambda step: step % 2000 == 0,
    video_length=200,
)

# Instantiate A2C on the recorded, vectorized environment
model = A2C(
    policy=config["policy_type"],
    env=env,
    tensorboard_log=f"runs/{run.id}",  # one TensorBoard folder per W&B run
    verbose=1,
)

# Run training; the W&B callback is built first so the learn() call stays short
wandb_callback = WandbCallback(
    model_save_path=f"models/{run.id}",
    gradient_save_freq=100,
    verbose=2,
)
model.learn(total_timesteps=config["total_timesteps"], callback=wandb_callback)

# Persist the trained agent under a file name that embeds the W&B run id
model_name = f"a2c_panda_reach_{run.id}"
model.save(path=model_name)

# Evaluate the trained model over 10 episodes.
# The evaluation environment is built with make_env() so that, like the
# training environments, it is wrapped in Monitor. evaluate_policy() warns on
# un-monitored environments because the episode rewards/lengths it reports
# may then be altered by other wrappers.
eval_env = make_env()
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Send the evaluation summary to the W&B run
eval_metrics = {"eval/mean_reward": mean_reward, "eval/std_reward": std_reward}
wandb.log(eval_metrics)

# Publish the trained agent to the Hugging Face Hub repository below
repo_id = "SimRams/a2c-panda-reach"
package_to_hub(
    repo_id=repo_id,
    commit_message="Training A2C on PandaReachJointsDense-v3",
    model=model,
    model_name=model_name,
    model_architecture="A2C",
    env_id=config["env_id"],
    eval_env=eval_env,
)

# Tear down: close the training and evaluation environments, then end the run
for open_env in (env, eval_env):
    open_env.close()
wandb.finish()
