import numpy as np
from read_cifar import read_cifar
from read_cifar import read_cifar_preprocess
from read_cifar import split_dataset
import matplotlib.pyplot as plt
import math as ma


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    The value is evaluated as exp(-log(1 + exp(-x))), with the logarithm
    taken through ``np.logaddexp``.

    Args:
        x: Scalar or NumPy array of pre-activations.

    Returns:
        The sigmoid of ``x``, with the same shape as ``x``.
    """
    log_denominator = np.logaddexp(0, -x)  # log(exp(0) + exp(-x))
    return np.exp(-log_denominator)


def learn_once_mse(w1,b1,w2,b2,data,targets,learning_rate):
    """Perform one full-batch gradient-descent step on the MSE loss.

    The network is a two-layer perceptron with a sigmoid activation on both
    the hidden and the output layer. The loss is the mean of the squared
    differences between the network output and the one-hot encoded targets,
    taken over every element of the batch.

    Args:
        w1: Weights of the first layer (d_in x d_h as built by
            run_mlp_training).
        b1: Biases of the first layer.
        w2: Weights of the second layer.
        b2: Biases of the second layer.
        data: Input samples, one row per sample.
        targets: Class label of each sample.
        learning_rate: Step size of the gradient descent.

    Returns:
        Tuple ``(w1, b1, w2, b2, loss)``. The parameter arrays are updated
        in place and returned; ``loss`` is the MSE measured before the update.
    """
    targets_one_hot = one_hot(targets)

    # Forward pass
    a0 = data  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (overflow-safe sigmoid)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = sigmoid(z2)  # output of the output layer
    # The loss must be measured on the same quantity the gradient is derived
    # from, i.e. the raw network output (no batch-wide renormalisation).
    predictions = a2

    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets_one_hot))

    # Backpropagation
    # np.mean divides by the total number of elements (batch size x number of
    # outputs), so the gradient carries the same factor.
    dc_da2 = 2 * (a2 - targets_one_hot) / a2.size
    dc_dz2 = dc_da2 * a2 * (1 - a2)
    dc_dw2 = np.dot(a1.T, dc_dz2)
    dc_db2 = np.sum(dc_dz2, axis=0)
    dc_da1 = np.dot(dc_dz2, w2.T)
    dc_dz1 = dc_da1 * a1 * (1 - a1)
    dc_dw1 = np.dot(a0.T, dc_dz1)
    dc_db1 = np.sum(dc_dz1, axis=0)

    # Gradient-descent update (in place on the arrays passed by the caller)
    w1 -= dc_dw1*learning_rate
    w2 -= dc_dw2*learning_rate
    b1 -= dc_db1*learning_rate
    b2 -= dc_db2*learning_rate

    return w1,b1,w2,b2,loss


def one_hot(array):
    """Encode class labels as one-hot vectors.

    The distinct labels are sorted, and the column of a label is its rank in
    that sorted order. For labels 0..K-1 that are all present, the column
    index therefore equals the label value, which is what an
    ``argmax == label`` comparison expects.

    Args:
        array: Array-like of labels, of any shape.

    Returns:
        Float array of shape ``array.shape + (n_classes,)`` where
        ``n_classes`` is the number of distinct labels found in ``array``.
    """
    labels = np.asarray(array)
    classes, inverse = np.unique(labels, return_inverse=True)
    # Depending on the NumPy version the inverse is returned flat or with the
    # input shape; reshaping makes both cases identical.
    inverse = inverse.reshape(labels.shape)
    # Row i of the identity matrix is the one-hot vector of class i.
    return np.eye(len(classes))[inverse]


def calculate_accuracy(predictions,labels):
    """Return the fraction of samples whose arg-max prediction equals its label.

    Args:
        predictions: Array-like with one row of class scores per sample.
        labels: Array-like holding the expected class index of each sample.
            Only the first ``len(predictions)`` labels are used.

    Returns:
        The accuracy as a float in [0, 1]; 0.0 when there are no predictions.

    Raises:
        ValueError: If there are fewer labels than predictions.
    """
    predictions = np.asarray(predictions)
    n_samples = len(predictions)
    if n_samples == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    expected = np.asarray(labels).reshape(-1)
    if expected.size < n_samples:
        raise ValueError(
            f"got {expected.size} labels for {n_samples} predictions"
        )

    # Flatten each sample's scores so the arg-max is taken per sample.
    predicted = np.argmax(predictions.reshape(n_samples, -1), axis=1)
    return float(np.mean(predicted == expected[:n_samples]))


def softmax(x):
    """Apply the softmax function along axis 1.

    The maximum of each row is subtracted before exponentiating; this leaves
    the result unchanged mathematically and keeps the exponentials bounded.

    Args:
        x: Array of scores with the classes along axis 1.

    Returns:
        Array of the same shape whose entries along axis 1 sum to 1.
    """
    row_peaks = np.max(x, axis=1, keepdims=True)
    unnormalized = np.exp(x - row_peaks)
    totals = np.sum(unnormalized, axis=1, keepdims=True)
    return unnormalized / totals


def learn_once_cross_entropy(w1,b1,w2,b2,data,labels_train,learning_rate):
    """Perform one full-batch gradient-descent step on the cross-entropy loss.

    The network is a two-layer perceptron with a sigmoid hidden layer and a
    softmax output layer.

    Args:
        w1: Weights of the first layer.
        b1: Biases of the first layer.
        w2: Weights of the second layer.
        b2: Biases of the second layer.
        data: Input samples, one row per sample.
        labels_train: Class label of each sample.
        learning_rate: Step size of the gradient descent.

    Returns:
        Tuple ``(w1, b1, w2, b2, accuracy)``. The parameter arrays are
        updated in place and returned; ``accuracy`` is the training accuracy
        measured with the parameters as they were before the update.
    """
    Y = one_hot(labels_train)

    # Forward pass
    a0 = data  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (overflow-safe sigmoid)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax(z2)  # output of the output layer (softmax activation function)

    # Accuracy for tracking progress
    accuracy = calculate_accuracy(a2, labels_train)

    # Backpropagation. For softmax combined with cross-entropy the gradient
    # with respect to z2 reduces to (a2 - Y); it is averaged over the batch.
    dc_dz2 = (a2 - Y) / a2.shape[0]
    dc_dw2 = np.dot(a1.T, dc_dz2)
    dc_db2 = np.sum(dc_dz2, axis=0)
    dc_da1 = np.dot(dc_dz2, w2.T)
    dc_dz1 = dc_da1 * a1 * (1 - a1)
    dc_dw1 = np.dot(a0.T, dc_dz1)
    dc_db1 = np.sum(dc_dz1, axis=0)

    # Gradient-descent update (in place on the arrays passed by the caller)
    w1 -= dc_dw1*learning_rate
    w2 -= dc_dw2*learning_rate
    b1 -= dc_db1*learning_rate
    b2 -= dc_db2*learning_rate

    return w1,b1,w2,b2,accuracy


def train_mlp(w1,b1,w2,b2,data,labels_train,learning_rate,num_epoch,loss):
    """Train the network for ``num_epoch`` full-batch gradient steps.

    Args:
        w1: Weights of the first layer.
        b1: Biases of the first layer.
        w2: Weights of the second layer.
        b2: Biases of the second layer.
        data: Training samples, one row per sample.
        labels_train: Class label of each training sample.
        learning_rate: Step size of the gradient descent.
        num_epoch: Number of gradient steps to perform.
        loss: Objective to minimise, either ``"cross-entropy"`` or
            ``"mean square root"`` (the MSE objective).

    Returns:
        Tuple ``(w1, b1, w2, b2, accuracies)`` where ``accuracies`` holds the
        training accuracy measured at each epoch, before that epoch's update.

    Raises:
        ValueError: If ``loss`` is not one of the supported names.
    """
    if loss not in ("cross-entropy", "mean square root"):
        raise ValueError(
            f"unknown loss {loss!r}; expected 'cross-entropy' or 'mean square root'"
        )

    accuracies = []
    for n in range(num_epoch):
        if loss == "cross-entropy":
            w1,b1,w2,b2,acc = learn_once_cross_entropy(w1,b1,w2,b2,data,labels_train,learning_rate)
        else:
            # learn_once_mse returns the loss, not an accuracy, so the
            # accuracy is measured separately. It is taken before the update
            # to match what the cross-entropy step reports.
            acc = test_mlp(w1,b1,w2,b2,data,labels_train)
            w1,b1,w2,b2,_ = learn_once_mse(w1,b1,w2,b2,data,labels_train,learning_rate)
        accuracies.append(acc)
        print(f"epoch {n}, training accuracy : {acc}")
    return w1,b1,w2,b2,accuracies


def test_mlp(w1,b1,w2,b2,data_test,labels_test):
    """Evaluate the classification accuracy of the network on a dataset.

    Args:
        w1: Weights of the first layer.
        b1: Biases of the first layer.
        w2: Weights of the second layer.
        b2: Biases of the second layer.
        data_test: Samples to classify, one row per sample.
        labels_test: Expected class label of each sample.

    Returns:
        The accuracy returned by calculate_accuracy for the network's scores.
    """
    a0 = data_test  # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (overflow-safe sigmoid)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer

    # The predicted class is the arg-max over the output units. Sigmoid and
    # softmax are both increasing, so the arg-max can be read from z2
    # directly. Applying a sigmoid first can round several large scores to
    # exactly 1.0 (or small ones to 0.0), creating ties that hide the true
    # maximum.
    predictions = z2

    accuracy = calculate_accuracy(predictions,labels_test)
    return accuracy


def run_mlp_training(data_train, labels_train, data_test, labels_test,d_h,learning_rate,num_epoch,loss):
    """Initialise, train and evaluate a two-layer perceptron.

    Args:
        data_train: Training samples as a 2-D array, one row per sample.
        labels_train: Class label of each training sample.
        data_test: Test samples, one row per sample.
        labels_test: Class label of each test sample.
        d_h: Number of units in the hidden layer.
        learning_rate: Step size of the gradient descent.
        num_epoch: Number of training epochs.
        loss: Name of the objective, forwarded to train_mlp.

    Returns:
        Tuple ``(training_accuracies, test_accuracy)``: the per-epoch values
        returned by train_mlp and the accuracy returned by test_mlp.
    """
    # Layer sizes: input width from the data, output width from the number of
    # distinct training labels.
    _, d_in = data_train.shape
    d_out = np.unique(labels_train).size

    # Weights are drawn uniformly in [-1, 1); biases start at zero.
    w1 = np.random.rand(d_in, d_h) * 2 - 1
    b1 = np.zeros((1, d_h))
    w2 = np.random.rand(d_h, d_out) * 2 - 1
    b2 = np.zeros((1, d_out))

    # Training phase
    w1, b1, w2, b2, training_accuracies = train_mlp(
        w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch, loss
    )

    # Evaluation phase on the held-out data
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return training_accuracies, test_accuracy




if __name__=="__main__":

    # Experiment settings.
    split_factor = 0.9
    d_h = 64
    learning_rate = 0.01
    num_epoch = 50
    # Training objective; swap the two lines below to train with cross-entropy.
    loss = "mean square root"
    #loss = "cross-entropy"

    # Uncomment the read_cifar line (and comment the next one) to use the
    # non-normalized data.
    #data, labels=read_cifar("data/cifar-10-batches-py/")
    data, labels = read_cifar_preprocess("data/cifar-10-batches-py/")

    data_train, labels_train, data_test, labels_test = split_dataset(
        data, labels, split_factor
    )

    # Per the original author's note, read_cifar_preprocess already normalizes
    # the data, channel by channel. Uncomment the lines below when loading
    # with read_cifar to normalize without the per-channel treatment.

    #data_mean = np.mean(data_train, axis=0)
    #data_std = np.std(data_train, axis=0)
    #data_train = (data_train - data_mean) / (data_std + 1e-8)
    #data_test = (data_test - data_mean) / (data_std + 1e-8)

    training_acc, test_acc = run_mlp_training(
        data_train, labels_train, data_test, labels_test,
        d_h, learning_rate, num_epoch, loss,
    )

    # Plot the per-epoch training curve, with the test accuracy in the title.
    plt.plot(training_acc)
    plt.title(f"training accuracy across learning epoch. test accuracy: {test_acc}")
    plt.show()
