import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import gym
import torch
import torch.nn as nn
import torch.optim as optim


# Define the neural network for the policy
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.6)
        self.fc2 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
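
# Example usage (a hypothetical sanity check, assuming CartPole-v1's
# 4-dimensional observation and 2 discrete actions):
#
#   net = PolicyNetwork(4, 2)
#   probs = net(torch.zeros(1, 4))  # shape (1, 2); each row sums to 1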


if __name__ == "__main__":
    
    # Hyperparameters
    learning_rate = 5e-3
    gamma = 0.99
    episodes = 450

    # Environment setup
    env = gym.make("CartPole-v1")  # add render_mode="human" to visualize episodes

    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    # Policy network
    policy = PolicyNetwork(input_dim, output_dim)
    optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    # Training loop
    episode_rewards = []
    for episode in tqdm(range(episodes)):

        # Gym >= 0.26 returns (observation, info) from reset(), so take index [0].
        # On older Gym versions (e.g. some Google Colab installs) reset() returns
        # only the observation, so use "state = env.reset()" instead.
        state = env.reset()[0]

        saved_log_probs = []
        saved_entropies = []
        rewards = []

        while True:
            # Compute action probabilities
            state_tensor = torch.from_numpy(state).float().unsqueeze(0)
            action_probs = policy(state_tensor)

            # Sample an action from the categorical distribution
            m = torch.distributions.Categorical(action_probs)
            action = m.sample()
            saved_log_probs.append(m.log_prob(action))
            saved_entropies.append(m.entropy())

            # Step the environment with the sampled action
            state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            if terminated or truncated:
                break

        # Compute discounted returns (backwards, using gamma) and normalize them
        returns = []
        G = 0.0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        # Policy-gradient loss plus an entropy bonus (averaged over all steps)
        # to encourage exploration
        policy_loss = -(torch.cat(saved_log_probs) * returns).sum()
        entropy_loss = -0.01 * torch.cat(saved_entropies).mean()
        total_loss = policy_loss + entropy_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        episode_reward = sum(rewards)
        episode_rewards.append(episode_reward)
        

    # Plotting
    plt.plot(episode_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Total reward')
    plt.title('REINFORCE on CartPole')
    plt.savefig('rewards_cartpole.png')
    plt.show()
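
    # A minimal evaluation sketch (an addition, assuming the same Gym >= 0.26
    # reset/step API used above): run a few greedy rollouts with dropout
    # disabled and report the mean return.
    policy.eval()
    eval_returns = []
    with torch.no_grad():
        for _ in range(10):
            state = env.reset()[0]
            total_reward = 0.0
            while True:
                probs = policy(torch.from_numpy(state).float().unsqueeze(0))
                action = probs.argmax(dim=1).item()  # greedy (most probable) action
                state, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward
                if terminated or truncated:
                    break
            eval_returns.append(total_reward)
    print(f"Mean greedy return over 10 evaluation episodes: {np.mean(eval_returns):.1f}")

    env.close()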