import os

import numpy as np
import matplotlib.pyplot as plt

from read_cifar import read_cifar, split_dataset

def sigmoid(x):
    return 1 / (1 + np.exp(-x))
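
# Note: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); the backward passes below
# use this identity through the terms a * (1 - a).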

def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
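    """Perform one gradient descent step on the MSE loss.

    w1, b1 (resp. w2, b2) are the weights and biases of the hidden (resp. output)
    layer, data is the (N, d_in) input batch and targets its (N, d_out) desired
    outputs. Returns the updated w1, b1, w2, b2 and the loss before the update.
    """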
    N = len(targets) # number of training examples
    
    # Forward pass
    a0 = data  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = sigmoid(z2)  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer
    
    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets))
    
    # Backward pass (gradients of the MSE loss, averaged over the N examples)
    d_a2 = 2 / N * (a2 - targets)
    d_z2 = d_a2 * a2 * (1 - a2)
    d_w2 = np.matmul(a1.T, d_z2)
    d_b2 = np.sum(d_z2, axis=0, keepdims=True)  # sum over the batch so the shape matches b2
    d_a1 = np.matmul(d_z2, w2.T)
    d_z1 = d_a1 * a1 * (1 - a1)
    d_w1 = np.matmul(a0.T, d_z1)
    d_b1 = np.sum(d_z1, axis=0, keepdims=True)  # sum over the batch so the shape matches b1
    
    # Gradient descent step: update the weights and biases
    w1 -= learning_rate * d_w1
    w2 -= learning_rate * d_w2
    b2 -= learning_rate * d_b2
    b1 -= learning_rate * d_b1
    
    return w1, b1, w2, b2, loss
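

# A quick sanity check for the analytical gradients above is a finite-difference
# comparison: nudge one weight, recompute the loss, and compare the slope with
# the backpropagated value. This is a debugging sketch, not part of the required
# API; the helper name check_gradient_w2 and its defaults are our own choices.
def check_gradient_w2(w1, b1, w2, b2, data, targets, index=(0, 0), eps=1e-5):
    N = len(targets)

    def loss_at(w2_candidate):
        # Same per-example averaging convention as d_a2 above
        # (sum over the outputs, mean over the N examples).
        a1 = sigmoid(np.matmul(data, w1) + b1)
        a2 = sigmoid(np.matmul(a1, w2_candidate) + b2)
        return np.sum(np.square(a2 - targets)) / N

    w2_plus, w2_minus = w2.copy(), w2.copy()
    w2_plus[index] += eps
    w2_minus[index] -= eps
    # Central difference; should match d_w2[index] up to O(eps**2)
    return (loss_at(w2_plus) - loss_at(w2_minus)) / (2 * eps)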


def one_hot(labels):
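    """Return the one-hot encoding of a vector of integer labels.

    Example (sketch): one_hot(np.array([0, 2, 1])) returns
        [[1., 0., 0.],
         [0., 0., 1.],
         [0., 1., 0.]]
    """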
    # Total number of classes
    num_classes = np.max(labels) + 1
    # Each label indexes the corresponding row of the identity matrix
    one_hot_matrix = np.eye(num_classes)[labels]
    return one_hot_matrix


def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
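    """Perform one gradient descent step on the cross-entropy loss.

    Same arguments as learn_once_mse, except that labels_train is a (N,) vector
    of integer class labels (one-hot encoded internally). Returns the updated
    w1, b1, w2, b2 and the loss before the update.
    """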
    N = len(labels_train) # number of training examples
    
    # Forward pass
    a0 = data  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = sigmoid(z2)  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer
    
    targets_one_hot = one_hot(labels_train)  # targets as one-hot encodings of the desired labels
    
    # Cross-entropy loss (a small epsilon guards against log(0))
    loss = -np.sum(targets_one_hot * np.log(predictions + 1e-9)) / N
    
    # Backpropagation
    d_z2 = a2 - targets_one_hot  # simplified gradient of the cross-entropy loss w.r.t. z2
    d_w2 = np.dot(a1.T, d_z2) / N
    d_b2 = np.sum(d_z2, axis=0, keepdims=True) / N  # sum over the batch so the shape matches b2
    d_a1 = np.dot(d_z2, w2.T)
    d_z1 = d_a1 * a1 * (1 - a1)
    d_w1 = np.dot(a0.T, d_z1) / N
    d_b1 = np.sum(d_z1, axis=0, keepdims=True) / N  # sum over the batch so the shape matches b1
    
    # Gradient descent step: update the weights and biases
    w1 -= learning_rate * d_w1
    w2 -= learning_rate * d_w2
    b2 -= learning_rate * d_b2
    b1 -= learning_rate * d_b1
    
    return w1, b1, w2, b2, loss


def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
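    """Train the network for num_epoch epochs of full-batch gradient descent.

    Returns the trained w1, b1, w2, b2 and the list of training accuracies
    recorded after each epoch.
    """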
    train_accuracies = [0] * num_epoch
    for i in range(num_epoch):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)
        
        # Forward pass with the updated parameters
        a0 = data_train  # the data are the input of the first layer
        z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
        a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
        z2 = np.matmul(a1, w2) + b2  # input of the output layer
        a2 = sigmoid(z2)  # output of the output layer (sigmoid activation function)
        predictions = a2  # the predicted values are the outputs of the output layer
        
        # Predicted class = output neuron with the highest activation
        predicted_labels = np.argmax(predictions, axis=1)
        
        # Training accuracy for this epoch
        train_accuracies[i] = np.mean(labels_train == predicted_labels)
        
    return w1, b1, w2, b2, train_accuracies


def test_mlp(w1, b1, w2, b2, data_test, labels_test):
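    """Evaluate the network on the test set and return its accuracy."""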
    
    # Forward pass
    a0 = data_test  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = sigmoid(z2)  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer
    
    # Predicted class = output neuron with the highest activation
    predicted_labels = np.argmax(predictions, axis=1)
    
    # Test accuracy
    test_accuracy = np.mean(predicted_labels == labels_test)

    return test_accuracy
    

def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
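    """Initialize a 1-hidden-layer MLP (d_h hidden neurons), train it and test it.

    Returns the list of per-epoch training accuracies and the final test accuracy.
    """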
    
    # Define parameters
    d_in = data_train.shape[1] # number of input neurons
    d_out = len(np.unique(labels_train)) # number of output neurons = number of classes
    
    # Random initialization of the network weights and biases
    w1 = 2 * np.random.rand(d_in, d_h) - 1  # first layer weights, uniform in [-1, 1)
    b1 = np.zeros((1, d_h))  # first layer biases
    w2 = 2 * np.random.rand(d_h, d_out) - 1  # second layer weights, uniform in [-1, 1)
    b2 = np.zeros((1, d_out))  # second layer biases
    
    # Training of the MLP classifier with num_epoch steps
    w1, b1, w2, b2, train_accuracies = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    
    # Computation of the final test accuracy with the trained weights and biases
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return train_accuracies, test_accuracy


if __name__ == "__main__":
    
    split_factor = 0.9
    d_h = 64
    learning_rate = 0.1
    num_epoch = 100
    
    data, labels = read_cifar("./data/cifar-10-batches-py")
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split_factor)
    
    epochs = list(range(1, num_epoch + 1))
    
    # Train once for num_epoch epochs; train_mlp records the training accuracy after each epoch
    train_accuracies, test_accuracy = run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)
    print(f"Final test accuracy: {test_accuracy:.4f}")
    
    plt.plot(epochs, train_accuracies)
    plt.title("Evolution of learning accuracy across learning epochs")
    plt.xlabel("number of epochs")
    plt.ylabel("Accuracy")
    plt.grid(True, which='both')
    os.makedirs("results", exist_ok=True)  # make sure the output directory exists
    plt.savefig("results/mlp.png")