import gymnasium as gym
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from reinforce_cartpole import PolicyNetwork

def evaluate_reinforce_cpole():
    """Evaluate a trained REINFORCE policy on CartPole-v1.

    Loads the saved policy weights, runs ``num_episodes`` greedy (argmax)
    episodes with on-screen rendering, and prints each episode's score,
    the overall success rate, and the mean score.

    Side effects: opens a render window, reads ``saves/policy_cartpole.pth``,
    prints to stdout.
    """
    env = gym.make("CartPole-v1", render_mode="human")
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Load the trained model. map_location ensures a checkpoint trained on
    # GPU can still be loaded on a CPU-only machine.
    policy = PolicyNetwork(obs_dim, action_dim)
    policy.load_state_dict(
        torch.load("saves/policy_cartpole.pth", map_location="cpu")
    )
    policy.eval()  # evaluation mode (disables dropout/batch-norm updates)

    num_episodes = 100
    success_threshold = 400  # per-episode score required to count as a success
    success_count = 0
    scores = []

    try:
        for episode in range(num_episodes):
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32)
            done = False
            total_reward = 0

            while not done:
                # Greedy evaluation: pick the most probable action
                # instead of sampling from the policy distribution.
                with torch.no_grad():
                    action_probs = policy(state)
                    action = torch.argmax(action_probs).item()

                next_state, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward

                state = torch.tensor(next_state, dtype=torch.float32)
                done = terminated or truncated

            scores.append(total_reward)
            if total_reward >= success_threshold:
                success_count += 1

            print(f"Épisode {episode+1}: Score = {total_reward}")

        success_rate = success_count / num_episodes * 100
        print(f"\nSuccès: {success_count}/{num_episodes} ({success_rate:.2f}%)")
        # `scores` was previously collected but never reported; surface the mean.
        print(f"Score moyen: {np.mean(scores):.2f}")
    finally:
        # Always release the render window, even if an episode raises.
        env.close()

def _main() -> None:
    """Script entry point: run the CartPole evaluation loop."""
    evaluate_reinforce_cpole()


if __name__ == "__main__":
    _main()