import numpy as np
from read_cifar import read_cifar
from read_cifar import read_cifar_preprocess
from read_cifar import split_dataset
import matplotlib.pyplot as plt
import math as ma


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    The value is computed as exp(-log(1 + exp(-x))) so that large negative
    inputs do not overflow an intermediate exp().
    """
    # log(exp(0) + exp(-x)) == log(1 + exp(-x)), evaluated stably by NumPy.
    log_denominator = np.logaddexp(0, -x)
    return np.exp(-log_denominator)


def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
    """Perform one gradient-descent step using a mean-squared-error loss.

    The network computed here is
    ``softmax(sigmoid(data @ w1 + b1) @ w2 + b2)`` and the loss is the mean,
    over every entry of the prediction matrix, of the squared difference
    with the one-hot encoded targets.

    Args:
        w1, b1: weights and bias of the hidden layer.
        w2, b2: weights and bias of the output layer.
        data: input batch, one sample per row.
        targets: class label of each row of ``data``; encoded with ``one_hot``.
        learning_rate: gradient-descent step size.

    Returns:
        Tuple ``(w1, b1, w2, b2, loss)``. The parameter arrays are updated
        in place and also returned; ``loss`` is the value computed before
        the update.
    """
    # Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = sigmoid(z1)
    z2 = np.matmul(a1, w2) + b2
    a2 = softmax(z2)

    targets_one_hot = one_hot(targets)

    # Compute loss (MSE)
    error = a2 - targets_one_hot
    loss = np.mean(np.square(error))

    # Backpropagation
    # Gradient of np.mean(error ** 2): the mean divides by the total number
    # of entries (samples * outputs), so the gradient must as well.
    dc_da2 = 2 * error / error.size
    # The output activation is a softmax, so propagate through the softmax
    # Jacobian: dz_j = a_j * (da_j - sum_i(da_i * a_i)), row by row.
    dc_dz2 = a2 * (dc_da2 - np.sum(dc_da2 * a2, axis=1, keepdims=True))
    dc_dw2 = np.dot(a1.T, dc_dz2)
    dc_db2 = np.sum(dc_dz2, axis=0)
    dc_da1 = np.dot(dc_dz2, w2.T)
    # Hidden activation is a sigmoid: derivative is a1 * (1 - a1).
    dc_dz1 = dc_da1 * a1 * (1 - a1)
    dc_dw1 = np.dot(a0.T, dc_dz1)
    dc_db1 = np.sum(dc_dz1, axis=0)

    # In-place updates: the caller's arrays are modified.
    w1 -= dc_dw1 * learning_rate
    w2 -= dc_dw2 * learning_rate
    b1 -= dc_db1 * learning_rate
    b2 -= dc_db2 * learning_rate

    return w1, b1, w2, b2, loss


def one_hot(array):
    """One-hot encode an array of labels.

    The distinct labels are sorted, and column ``j`` of the encoding
    corresponds to the ``j``-th smallest distinct label. When the labels are
    the integers ``0 .. K-1`` and all of them are present, the column index
    therefore equals the label value, which is what comparing
    ``np.argmax(prediction)`` with the raw label (see ``calculate_accuracy``)
    relies on.

    Args:
        array: array-like of labels, any shape.

    Returns:
        Float array of shape ``array.shape + (K,)`` where ``K`` is the number
        of distinct labels found in ``array``; exactly one entry along the
        last axis is 1 for each input element.
    """
    array = np.asarray(array)
    # Flatten first so that `inverse` is 1-D regardless of the NumPy version.
    classes, inverse = np.unique(array.ravel(), return_inverse=True)
    encoding = np.zeros((array.size, len(classes)))
    encoding[np.arange(array.size), inverse] = 1
    return encoding.reshape(array.shape + (len(classes),))


def calculate_accuracy(predictions, labels):
    """Return the fraction of samples whose arg-max prediction equals its label.

    Args:
        predictions: array-like with one row of scores per sample.
        labels: array-like holding the expected class index of each sample.

    Returns:
        Accuracy as a Python float in ``[0, 1]``; ``0.0`` when
        ``predictions`` is empty.

    Raises:
        IndexError: if there are fewer labels than predictions.
    """
    predictions = np.asarray(predictions)
    n_samples = predictions.shape[0]
    if n_samples == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    labels = np.asarray(labels).reshape(-1)
    if labels.size < n_samples:
        raise IndexError(
            f"got {labels.size} labels for {n_samples} predictions"
        )

    # Arg-max over everything belonging to one sample, for all samples at once.
    predicted_classes = predictions.reshape(n_samples, -1).argmax(axis=1)
    n_correct = int(np.count_nonzero(predicted_classes == labels[:n_samples]))
    return n_correct / n_samples


def softmax(x):
    """Apply the softmax function to every row of a 2-D array.

    The maximum of each row is subtracted before exponentiating; this leaves
    the result unchanged but keeps exp() from overflowing on large scores.
    """
    row_max = np.max(x, axis=1, keepdims=True)
    exponentials = np.exp(x - row_max)
    row_totals = exponentials.sum(axis=1, keepdims=True)
    return exponentials / row_totals


def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
    """Perform one gradient-descent step using a cross-entropy loss.

    The network computed here is
    ``softmax(sigmoid(data @ w1 + b1) @ w2 + b2)``. The loss is the
    cross-entropy between the predictions and the one-hot encoded labels,
    averaged over the samples of the batch.

    Args:
        w1, b1: weights and bias of the hidden layer.
        w2, b2: weights and bias of the output layer.
        data: input batch, one sample per row.
        labels_train: class label of each row of ``data``; encoded with
            ``one_hot``.
        learning_rate: gradient-descent step size.

    Returns:
        Tuple ``(w1, b1, w2, b2, loss)``. The parameter arrays are updated
        in place and also returned; ``loss`` is the value computed before
        the update.
    """
    Y = one_hot(labels_train)

    # Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = sigmoid(z1)
    z2 = np.matmul(a1, w2) + b2
    a2 = softmax(z2)

    n_samples = a2.shape[0]
    # The small constant keeps log() finite when a probability underflows to 0.
    loss = -np.sum(Y * np.log(a2 + 1e-8)) / n_samples

    # Backpropagation
    # For softmax followed by cross-entropy the gradient w.r.t. z2 is
    # (a2 - Y); the loss is a mean over the batch, hence the 1 / n_samples.
    dc_dz2 = (a2 - Y) / n_samples
    dc_dw2 = np.dot(a1.T, dc_dz2)
    dc_db2 = np.sum(dc_dz2, axis=0)
    dc_da1 = np.dot(dc_dz2, w2.T)
    # Hidden activation is a sigmoid: derivative is a1 * (1 - a1).
    dc_dz1 = dc_da1 * a1 * (1 - a1)
    dc_dw1 = np.dot(a0.T, dc_dz1)
    dc_db1 = np.sum(dc_dz1, axis=0)

    # In-place updates: the caller's arrays are modified.
    w1 -= dc_dw1 * learning_rate
    w2 -= dc_dw2 * learning_rate
    b1 -= dc_db1 * learning_rate
    b2 -= dc_db2 * learning_rate

    return w1, b1, w2, b2, loss


def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    """Train the network for ``num_epoch`` full-batch cross-entropy steps.

    After every step the accuracy on the training data is measured with
    ``test_mlp``, printed, and recorded.

    Returns:
        Tuple ``(w1, b1, w2, b2, train_accuracies)`` where
        ``train_accuracies`` holds one accuracy value per epoch.
    """
    accuracy_history = []

    for epoch in range(num_epoch):
        # The loss returned by the update step is not used here.
        w1, b1, w2, b2, _loss = learn_once_cross_entropy(
            w1, b1, w2, b2, data_train, labels_train, learning_rate
        )

        epoch_accuracy = test_mlp(w1, b1, w2, b2, data_train, labels_train)
        accuracy_history.append(epoch_accuracy)
        print(f"epoch {epoch}, training accuracy is : {epoch_accuracy}")

    return w1, b1, w2, b2, accuracy_history


def test_mlp(w1, b1, w2, b2, data_test, labels_test):
    """Run a forward pass on ``data_test`` and return the classification accuracy.

    The forward pass is ``softmax(sigmoid(data_test @ w1 + b1) @ w2 + b2)``;
    the resulting rows are scored against ``labels_test`` with
    ``calculate_accuracy``.
    """
    hidden_activation = sigmoid(np.matmul(data_test, w1) + b1)
    class_probabilities = softmax(np.matmul(hidden_activation, w2) + b2)
    return calculate_accuracy(class_probabilities, labels_test)


def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
    """Initialise a 2-layer MLP, train it, and evaluate it on the test set.

    Weights are drawn uniformly from [-1, 1) and biases start at zero.
    Training is delegated to ``train_mlp`` and the final evaluation to
    ``test_mlp``.

    Args:
        data_train: 2-D training data, one sample per row.
        labels_train: class label of each training sample.
        data_test: test data, one sample per row.
        labels_test: class label of each test sample.
        d_h: number of hidden units.
        learning_rate: gradient-descent step size.
        num_epoch: number of training epochs.

    Returns:
        Tuple ``(training_accuracies, test_accuracy)``: the per-epoch training
        accuracies and the accuracy on the test set after training.
    """
    _, d_in = data_train.shape
    # The output layer has one unit per distinct training label.
    d_out = len(np.unique(labels_train))

    w1 = 2 * np.random.rand(d_in, d_h) - 1
    w2 = 2 * np.random.rand(d_h, d_out) - 1
    b2 = np.zeros((1, d_out))
    b1 = np.zeros((1, d_h))

    w1, b1, w2, b2, training_accuracies = train_mlp(
        w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch
    )

    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return training_accuracies, test_accuracy

def run_mlp_training_v2(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
    """Variant of ``run_mlp_training`` with Xavier initialisation and a vanishing learning rate.

    Weights are drawn from a normal distribution scaled by
    ``sqrt(2 / (fan_in + fan_out))`` and biases start at zero. Training is
    delegated to ``train_mlp_v2`` (which decays the learning rate across
    epochs) and the final evaluation to ``test_mlp``.

    Args:
        data_train: 2-D training data, one sample per row.
        labels_train: class label of each training sample.
        data_test: test data, one sample per row.
        labels_test: class label of each test sample.
        d_h: number of hidden units.
        learning_rate: forwarded to ``train_mlp_v2``.
        num_epoch: number of training epochs.

    Returns:
        Tuple ``(training_accuracies, test_accuracy)``: the per-epoch training
        accuracies and the accuracy on the test set after training.
    """
    _, d_in = data_train.shape
    # The output layer has one unit per distinct training label.
    d_out = len(np.unique(labels_train))

    w1 = np.random.randn(d_in, d_h) * np.sqrt(2 / (d_in + d_h))
    w2 = np.random.randn(d_h, d_out) * np.sqrt(2 / (d_h + d_out))
    b2 = np.zeros((1, d_out))
    b1 = np.zeros((1, d_h))

    w1, b1, w2, b2, training_accuracies = train_mlp_v2(
        w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch
    )

    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return training_accuracies, test_accuracy

def train_mlp_v2(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch, decay_rate=0.1):
    """Train with cross-entropy steps and an exponentially decaying learning rate.

    Epoch ``n`` uses the step size ``learning_rate * exp(-decay_rate * n)``.
    After every step the accuracy on the training data is measured with
    ``test_mlp``, printed, and recorded.

    Args:
        w1, b1, w2, b2: network parameters to train.
        data_train: training data, one sample per row.
        labels_train: class label of each training sample.
        learning_rate: step size used for the first epoch.
        num_epoch: number of training epochs.
        decay_rate: exponential decay constant applied per epoch.

    Returns:
        Tuple ``(w1, b1, w2, b2, train_accuracies)`` where
        ``train_accuracies`` holds one accuracy value per epoch.
    """
    train_accuracies = []

    for n in range(num_epoch):
        # Scale the caller's learning rate instead of ignoring it.
        current_rate = learning_rate * np.exp(-decay_rate * n)
        w1, b1, w2, b2, _loss = learn_once_cross_entropy(
            w1, b1, w2, b2, data_train, labels_train, current_rate
        )
        acc = test_mlp(w1, b1, w2, b2, data_train, labels_train)
        train_accuracies.append(acc)

        print(f"epoch {n}, training accuracy is : {acc}")

    return w1, b1, w2, b2, train_accuracies



if __name__ == "__main__":

    # Experiment settings.
    split_factor = 0.9
    d_h = 64
    learning_rate = 0.1
    num_epoch = 100

    # Swap the two loader lines below to use the normalised data instead.
    data, labels = read_cifar("data/cifar-10-batches-py/")
    # data, labels = read_cifar_preprocess("data/cifar-10-batches-py/")

    data_train, labels_train, data_test, labels_test = split_dataset(
        data, labels, split_factor
    )

    # Swap the two lines below to run the v2 training variant instead.
    training_acc, test_acc = run_mlp_training(
        data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch
    )
    # training_acc, test_acc = run_mlp_training_v2(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)

    # Plot the per-epoch training accuracy; the final test accuracy goes in the title.
    figure_title = "training accuracy across learning epoch. final test accuracy is : " + str(test_acc)
    plt.plot(training_acc)
    plt.title(figure_title)
    plt.show()
