import os

import numpy as np
import matplotlib.pyplot as plt

def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
    """
    Perform one forward and backward pass of an MLP using the MSE loss.
    Returns:
        w1, b1, w2, b2 : the updated weights & biases of the MLP.
        loss : the loss
    """
    # Forward pass
    a0 = data                    # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))   # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = 1 / (1 + np.exp(-z2))   # output of the output layer (sigmoid activation function)
    predictions = a2             # the predicted values are the outputs of the output layer

    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets))

    # Backward pass
    # The MSE above averages over all entries of the output (predictions.size
    # of them), so the gradient of the loss w.r.t. a2 carries a factor
    # 1 / predictions.size; no further averaging is needed below.
    da2 = 2 * (predictions - targets) / predictions.size
    dz2 = da2 * a2 * (1 - a2)

    dw2 = np.dot(a1.T, dz2)
    db2 = np.sum(dz2, axis=0, keepdims=True)

    da1 = np.dot(dz2, w2.T)
    dz1 = da1 * a1 * (1 - a1)

    dw1 = np.dot(a0.T, dz1)
    db1 = np.sum(dz1, axis=0, keepdims=True)
    
    # One step of gradient descent
    w1 -= learning_rate * dw1
    w2 -= learning_rate * dw2
    b1 -= learning_rate * db1
    b2 -= learning_rate * db2

    return w1, b1, w2, b2, loss

def one_hot(x):
    """One-hot encode a list of sample labels. Return a one-hot encoded vector for each label."""
    n_classes = 10  # the number of classes is hardcoded to 10
    return np.eye(n_classes)[x]

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    # Subtracting the row-wise max improves numerical stability without
    # changing the result (softmax is invariant to shifting each row).
    exp = np.exp(x - np.max(x, axis=1, keepdims=True))
    return exp / np.sum(exp, axis=1, keepdims=True)

def learn_once_cross_entropy(w1, b1, w2, b2, data, targets, learning_rate):
    """
    Perform one forward and backward pass of an MLP using the cross-entropy loss.
    Returns:
        w1, b1, w2, b2 : the updated weights & biases of the MLP.
        loss : the loss
    """
    N = data.shape[0]

    # Forward pass
    a0 = data                       # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1     # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))      # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2     # input of the output layer
    a2 = softmax(z2)                # output of the output layer (softmax activation function)
    predictions = a2                # the predicted values are the outputs of the output layer

    # One-hot encode the targets
    oh_targets = one_hot(targets)

    # Compute the cross-entropy loss (i.e. the negative log-likelihood);
    # the small epsilon guards against log(0).
    loss = - np.sum(oh_targets * np.log(predictions + 1e-9)) / N

    # Backward pass
    # For softmax outputs with the cross-entropy loss, the gradient of the
    # loss w.r.t. z2 simplifies to (predictions - targets); the 1/N batch
    # average is applied in the weight and bias gradients below.
    dz2 = predictions - oh_targets

    dw2 = np.dot(a1.T, dz2) / N
    db2 = np.sum(dz2, axis=0, keepdims=True) / N
    
    da1 = np.dot(dz2, w2.T)
    dz1 = da1 * a1 * (1 - a1)

    dw1 = np.dot(a0.T, dz1) / N
    db1 = np.sum(dz1, axis=0, keepdims=True) / N
    
    # One step of gradient descent
    w1 -= learning_rate * dw1
    w2 -= learning_rate * dw2
    b1 -= learning_rate * db1
    b2 -= learning_rate * db2

    return w1, b1, w2, b2, loss

def predict_mlp(w1, b1, w2, b2, data):
    """Perform the forward pass of the MLP on data.
    Returns:
        numpy array: the predicted class index for each sample in data
    """
    # Forward pass
    a0 = data                    # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))   # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax(z2)             # output of the output layer (softmax activation function)
    predictions = np.argmax(a2, axis=1)
    
    return predictions

def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    """
    Perform num_epoch training steps of the MLP using the cross-entropy loss.
    Returns:
        w1, b1, w2, b2 : the updated weights & biases of the MLP after num_epoch training steps.
        train_accuracies : list of train accuracies across epochs.
    """

    train_accuracies = [0] * num_epoch
    for epoch in range(num_epoch):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)
        labels_pred = predict_mlp(w1, b1, w2, b2, data_train)
        accuracy = np.mean(labels_pred == labels_train)
        train_accuracies[epoch] = accuracy

        print(f"Epoch loss [{epoch+1}/{num_epoch}] : {loss} --- accuracy : {accuracy}")

    return w1, b1, w2, b2, train_accuracies

# This function can't be named 'test_mlp' because pytest collects any function
# whose name starts with 'test_' as a test case and would try to run it,
# thus I chose to name it 'Test_mlp'.
def Test_mlp(w1, b1, w2, b2, data_test, labels_test):
    """Test the MLP on test data and compute the accuracy.
    Returns:
        float: test accuracy on data_test
    """
    labels_pred = predict_mlp(w1, b1, w2, b2, data_test)
    test_accuracy = np.mean(labels_pred == labels_test)
    
    return test_accuracy

def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
    """
    Train a simple Neural Net with d_h hidden neurons and return the performance of the obtained model.
    Returns:
        train_accuracies (list): list of training accuracies over num_epoch steps.
        test_accuracy (float): the accuracy of the trained model on data_test.
    """
    d_in = data_train.shape[1]      # input dimension
    d_out = len(set(labels_train))  # number of classes (assumes all classes appear in labels_train)

    # Random initialization of the network weights and biases
    w1 = 2 * np.random.rand(d_in, d_h) - 1   # first layer weights, uniform in [-1, 1)
    b1 = np.zeros((1, d_h))                  # first layer biases
    w2 = 2 * np.random.rand(d_h, d_out) - 1  # second layer weights, uniform in [-1, 1)
    b2 = np.zeros((1, d_out))                # second layer biases

    w1, b1, w2, b2, train_accuracies = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    test_accuracy = Test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return train_accuracies, test_accuracy

def plot_accuracy_versus_epoch(accuracies):
    """Plot the variation of the accuracy as a function of the epoch and save
    the plot into the results/ directory.
    Args:
        accuracies (List): the list of training accuracies, one per epoch.
    """

    plt.figure(figsize=(18, 10))
    plt.plot(accuracies, 'o-b')
    plt.title("Variation of the accuracy over the epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.grid(axis='both', which='both')
    os.makedirs('results', exist_ok=True)  # create the output directory if it does not exist
    plt.savefig('results/mlp1.png')