import numpy as np
from read_cifar import read_cifar_batch, split_dataset
import matplotlib.pyplot as plt

def learning_methode(k,dk,learning_rate):
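    """Gradient-descent update: k <- k - learning_rate * dk."""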
    return k - learning_rate * dk

def learn_once_mse(w1,b1,w2,b2,data,targets,learning_rate):
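    """Perform one gradient-descent step on a batch.

    Forward pass through both sigmoid layers, backward pass, then a
    parameter update. Despite the name, the loss used here is the binary
    cross-entropy; the MSE variant is kept in comments.
    """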
    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = 1 / (1 + np.exp(-z2))  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    # Backward pass (binary cross-entropy loss with sigmoid outputs)
    # MSE alternative: dc_da2 = (2 / data.shape[0]) * (a2 - targets)
    dc_da2 = (1 / data.shape[0]) * (-(targets / a2) + (1 - targets) / (1 - a2))
    dc_dz2 = dc_da2 * (a2 * (1 - a2))  # simplifies to (a2 - targets) / data.shape[0]
    dc_dw2 = np.matmul(np.transpose(a1), dc_dz2)
    dc_db2 = np.matmul(np.ones((1, dc_dz2.shape[0])), dc_dz2)  # sums dc_dz2 over the batch
    dc_da1 = np.matmul(dc_dz2, np.transpose(w2))
    dc_dz1 = dc_da1 * (a1 * (1 - a1))
    dc_dw1 = np.matmul(np.transpose(a0), dc_dz1)
    dc_db1 = np.matmul(np.ones((1, dc_dz1.shape[0])), dc_dz1)

    w1=learning_methode(w1,dc_dw1,learning_rate)
    b1=learning_methode(b1,dc_db1,learning_rate)
    w2=learning_methode(w2,dc_dw2,learning_rate)
    b2=learning_methode(b2,dc_db2,learning_rate)

    # Binary cross-entropy loss
    # (MSE alternative: loss = np.mean(np.square(predictions - targets)))
    loss = -np.mean(targets * np.log(predictions) + (1 - targets) * np.log(1 - predictions))
    return(w1,b1,w2,b2,loss)
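
# Optional sanity check for the backward pass above: a minimal
# finite-difference sketch. `_bce_loss` is a hypothetical helper (not used by
# the training loop) that recomputes the forward pass and the cross-entropy
# loss, summed over classes and averaged over the batch to match the 1/N
# convention used by the gradients in learn_once_mse.
def _bce_loss(w1, b1, w2, b2, data, targets):
    a1 = 1 / (1 + np.exp(-(np.matmul(data, w1) + b1)))
    a2 = 1 / (1 + np.exp(-(np.matmul(a1, w2) + b2)))
    return np.mean(np.sum(-(targets * np.log(a2) + (1 - targets) * np.log(1 - a2)), axis=1))

# Usage sketch: for a batch `data` with one-hot targets `Y`, the numerical
# gradient below should be close to the analytic dc_dw2[0, 0]:
#   eps = 1e-5
#   w2p = w2.copy(); w2p[0, 0] += eps
#   w2m = w2.copy(); w2m[0, 0] -= eps
#   num_grad = (_bce_loss(w1, b1, w2p, b2, data, Y)
#               - _bce_loss(w1, b1, w2m, b2, data, Y)) / (2 * eps)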

def one_hot(label):
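    """Encode a vector of 1-based class labels as a one-hot matrix,
    e.g. label 3 -> a row with a 1 in column index 2."""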
    nbr_classe = 10  # CIFAR-10 has 10 classes
    mat = np.zeros((len(label), nbr_classe))
    for label_index, label_im in enumerate(label):
        mat[label_index, label_im - 1] = 1  # labels are 1-based
    return(mat)

def learn_once_cross_entropy(w1,b1,w2,b2,data,labels_train,learning_rate):
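    """One training step on integer class labels: one-hot encode them and
    delegate to learn_once_mse."""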
    Y=one_hot(labels_train)
    w1,b1,w2,b2,loss=learn_once_mse(w1,b1,w2,b2,data,Y,learning_rate)
    return(w1,b1,w2,b2,loss)

def train_mlp(w1,b1,w2,b2,d_train,labels_train,learning_rate,num_epoch):
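    """Train for num_epoch gradient steps and return the updated parameters
    plus the list of per-step training losses."""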
    train_losses = []
    # Each "epoch" takes one gradient step on its own disjoint slice of the
    # training set (mini-batch style), so the whole set is seen exactly once.
    pas = len(labels_train) // num_epoch
    for k in range(num_epoch):
        partial_data = d_train[k * pas:(k + 1) * pas, :]
        partial_label = labels_train[k * pas:(k + 1) * pas]
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, partial_data, partial_label, learning_rate)
        train_losses.append(loss)
    return (w1, b1, w2, b2, train_losses)

def test_mlp(w1,b1,w2,b2,d_test,labels_test):
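    """Run a forward pass on the test set and return the classification accuracy."""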
    a0 = d_test # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = 1 / (1 + np.exp(-z2))  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer
    # Predicted class = argmax over the output neurons (+1 for 1-based labels)
    predicted_classes = np.argmax(predictions, axis=1) + 1
    return(np.mean(predicted_classes == labels_test))

def run_mlp_training(data_train, labels_train, data_test, labels_test,d_h,learning_rate,num_epoch):
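    """Initialize a one-hidden-layer MLP with random weights, train it, and
    return (train_losses, test_accuracy)."""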
    d_in = data_train.shape[1]  # input dimension
    d_out = max(labels_train)  # output dimension (number of classes; labels are assumed 1-based here)

    w1 = 2 * np.random.rand(d_in, d_h) - 1  # first layer weights, uniform in [-1, 1)
    b1 = np.zeros((1, d_h))  # first layer biases
    w2 = 2 * np.random.rand(d_h, d_out) - 1  # second layer weights, uniform in [-1, 1)
    b2 = np.zeros((1, d_out))  # second layer biases

    w1, b1, w2, b2, train_losses = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)
    test_accuracy2 = unit_test(w1, b1, w2, b2, data_test, labels_test)  # per-sample sanity check
    print(test_accuracy, test_accuracy2)
    return(train_losses, test_accuracy)

def unit_test(w1,b1,w2,b2,data_test, labels_test):
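    """Per-sample accuracy check; should return the same value as the
    vectorized test_mlp."""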
    pos = 0
    for index, image in enumerate(data_test):
        a0 = [image]  # a single image as a 1-row batch
        z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
        a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
        z2 = np.matmul(a1, w2) + b2  # input of the output layer
        a2 = 1 / (1 + np.exp(-z2))  # output of the output layer (sigmoid activation function)
        predicted_class = np.argmax(a2[0]) + 1  # labels are 1-based

        if predicted_class == labels_test[index]:
            pos += 1
    return(pos / len(labels_test))

if __name__ == "__main__":
    d, l = read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
    num_epoch=100
    d_train, l_train, d_test, l_test = split_dataset(d, l, 0.9)

    train_losses, test_accuracy = run_mlp_training(d_train, l_train, d_test, l_test, 64, 0.1, num_epoch)
    print(test_accuracy)
    plt.plot(range(num_epoch), train_losses, label='training loss per epoch')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()