## Imports
import numpy as np


## QUESTION 10
def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
    """Perform one gradient descent step of the neural network using the MSE cost.
    Args:
        w1: A np.float32 array of shape d_in x d_h, the first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the second layer biases.
        data: A np.float32 array of shape batch_size x d_in, the input data.
        targets: A np.float32 array of shape batch_size x d_out, the targets.
        learning_rate: The learning rate.
    Returns:
        w1: A np.float32 array of shape d_in x d_h, the updated first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the updated first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the updated second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the updated second layer biases.
        loss: The cost of the network on the given data.
    """
    # Check shapes
    assert w1.shape[0] == data.shape[1] # d_in
    assert w1.shape[1] == b1.shape[1] # d_h
    assert b1.shape[0] == 1
    assert w2.shape[0] == b1.shape[1] # d_h
    assert w2.shape[1] == b2.shape[1] # d_out
    assert b2.shape[0] == 1
    assert data.shape[0] == targets.shape[0] # batch_size

    N = data.shape[0] # batch_size
    d_in = data.shape[1] # Number of input neurons
    d_h = w1.shape[1] # Number of hidden neurons
    d_out = w2.shape[1] # Number of output neurons

    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1 # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1)) # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2 # input of the output layer
    a2 = 1 / (1 + np.exp(-z2)) # output of the output layer (sigmoid activation function)
    predictions = a2 # shape batch_size x d_out

    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets)) # scalar

    # Backward pass
    # Compute gradients
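    # Chain-rule roadmap (summary): with cost C = mean((a2 - y)^2) over all
    # N * d_out entries and sigmoid activations, each step below applies
    # dC/dx = dC/dy * dy/dx, propagating from the output back to w1 and b1.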

    ## QUESTION 2
    dcost_da2 = 2 * (predictions - targets) / (N * d_out) # shape batch_size x d_out (the mean averages over all N * d_out entries)
    # print("dcost_da2.shape = ", dcost_da2.shape)

    ## QUESTION 3
    da2_dz2 = a2 * (1 - a2) # shape batch_size x d_out
    # print("da2_dz2.shape = ", da2_dz2.shape)
    dcost_dz2 = dcost_da2 * da2_dz2 # shape batch_size x d_out
    # print("dcost_dz2.shape = ", dcost_dz2.shape)

    ## QUESTION 4
    dz2_dw2 = np.transpose(a1) # shape d_h x batch_size
    # print("dz2_dw2.shape = ", dz2_dw2.shape)
    dcost_dw2 = np.matmul(dz2_dw2, dcost_dz2) # shape d_h x d_out (the matmul sums over the batch axis)
    # print("dcost_dw2.shape = ", dcost_dw2.shape)

    ## QUESTION 5
    dz2_db2 = np.ones(d_out) # shape (d_out,)
    # print("dz2_db2.shape = ", dz2_db2.shape)
    dcost_db2 = np.sum(dcost_dz2 * dz2_db2, axis=0, keepdims=True) # shape 1 x d_out (summed over the batch axis, to match b2)
    # print("dcost_db2.shape = ", dcost_db2.shape)

    ## QUESTION 6
    dz2_da1 = np.transpose(w2) # shape d_out x d_h
    # print("dz2_da1.shape = ", dz2_da1.shape)
    dcost_da1 = np.matmul(dcost_dz2, dz2_da1) # shape batch_size x d_h
    # print("dcost_da1.shape = ", dcost_da1.shape)

    ## QUESTION 7
    da1_dz1 = a1 * (1 - a1) # shape batch_size x d_h
    # print("da1_dz1.shape = ", da1_dz1.shape)
    dcost_dz1 = dcost_da1 * da1_dz1 # shape batch_size x d_h
    # print("dcost_dz1.shape = ", dcost_dz1.shape)

    ## QUESTION 8
    dz1_dw1 = np.transpose(a0) # shape d_in x batch_size
    # print("dz1_dw1.shape = ", dz1_dw1.shape)
    dcost_dw1 = np.matmul(dz1_dw1, dcost_dz1) # shape d_in x d_h
    # print("dcost_dw1.shape = ", dcost_dw1.shape)

    ## QUESTION 9
    dz1_db1 = np.ones(d_h) # shape (d_h,)
    # print("dz1_db1.shape = ", dz1_db1.shape)
    dcost_db1 = np.sum(dcost_dz1 * dz1_db1, axis=0, keepdims=True) # shape 1 x d_h (summed over the batch axis, to match b1)
    # print("dcost_db1.shape = ", dcost_db1.shape)

    # Update weights and biases
    w1 = w1 - learning_rate * dcost_dw1
    b1 = b1 - learning_rate * dcost_db1
    w2 = w2 - learning_rate * dcost_dw2
    b2 = b2 - learning_rate * dcost_db2

    return w1, b1, w2, b2, loss
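

# A quick way to validate the analytic gradients above is a finite-difference
# check. This is a minimal sketch, not part of the assignment: the helper name
# check_grad_w2 and the default eps are illustrative. The returned value should
# closely match dcost_dw2[i, j] as computed in learn_once_mse.
def check_grad_w2(w1, b1, w2, b2, data, targets, i=0, j=0, eps=1e-5):
    def mse(w2_candidate):
        # Same forward pass as learn_once_mse, with w2 replaced
        a1 = 1 / (1 + np.exp(-(np.matmul(data, w1) + b1)))
        a2 = 1 / (1 + np.exp(-(np.matmul(a1, w2_candidate) + b2)))
        return np.mean(np.square(a2 - targets))
    w2_plus, w2_minus = w2.copy(), w2.copy()
    w2_plus[i, j] += eps
    w2_minus[i, j] -= eps
    # Central difference approximation of d(cost)/d(w2[i, j])
    return (mse(w2_plus) - mse(w2_minus)) / (2 * eps)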


## QUESTION 11
def one_hot(labels, num_classes=None):
    """Convert a 1-D vector of labels to a one-hot matrix.
    Args:
        labels: A np.int64 array of shape batch_size, the labels.
        num_classes: The number of classes; defaults to labels.max() + 1.
    Returns:
        b: A np.int64 array of shape batch_size x num_classes, the one-hot matrix.
    """
    if num_classes is None:
        num_classes = labels.max() + 1
    b = np.zeros((labels.size, num_classes), dtype=np.int64)
    b[np.arange(labels.size), labels] = 1
    return b
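
# Example: one_hot(np.array([0, 2, 1])) returns
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]]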


## QUESTION 12
def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
    """Perform one gradient descent step of the neural network using a binary cross-entropy loss.
    The last activation layer of the network is a softmax layer.
    Args:
        w1: A np.float32 array of shape d_in x d_h, the first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the second layer biases.
        data: A np.float32 array of shape batch_size x d_in, the input data.
        labels_train: A np.int64 array of shape batch_size, the labels of the training set.
        learning_rate: The learning rate.
    Returns:
        w1: A np.float32 array of shape d_in x d_h, the updated first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the updated first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the updated second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the updated second layer biases.
        loss: The loss of the network on the given data.
    """
    N = data.shape[0] # batch_size
    d_h = w1.shape[1] # Number of hidden neurons
    d_out = w2.shape[1] # Number of output neurons (robust even if a batch lacks the highest label)

    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1 # input of the hidden layer
    z1 = np.clip(z1, -500, 500) # avoid overflow in np.exp (sigmoid saturates long before this)
    a1 = 1 / (1 + np.exp(-z1)) # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2 # input of the output layer
    z2_shifted = z2 - np.max(z2, axis=1, keepdims=True) # shift for numerical stability (softmax is shift-invariant)
    a2 = np.exp(z2_shifted) / np.sum(np.exp(z2_shifted), axis=1, keepdims=True) # output of the output layer (softmax activation function)
    predictions = a2 # shape batch_size x d_out
    # print("predictions.shape = ", predictions.shape)

    # Compute loss (cross-entropy)
    labels_train_one_hot = one_hot(labels_train, d_out) # shape batch_size x d_out
    eps = 1e-12 # avoid log(0)
    predictions_clipped = np.clip(predictions, eps, 1 - eps)
    loss = - np.mean(labels_train_one_hot * np.log(predictions_clipped) + (1 - labels_train_one_hot) * np.log(1 - predictions_clipped)) # scalar

    # Backward pass
    # Compute gradients
    # With a softmax output, the gradient of the cross-entropy with respect to
    # the logits z2 simplifies to (predictions - one-hot targets).
    # print("labels_train.shape =", labels_train.shape)
    dloss_dz2 = predictions - labels_train_one_hot # shape batch_size x d_out
    # print("dloss_dz2.shape = ", dloss_dz2.shape)

    dz2_dw2 = np.transpose(a1) # shape d_h x batch_size
    # print("dz2_dw2.shape = ", dz2_dw2.shape)
    dloss_dw2 = np.matmul(dz2_dw2, dloss_dz2) # shape d_h x d_out
    # print("dloss_dw2.shape = ", dloss_dw2.shape)

    dz2_db2 = np.ones(d_out) # shape (d_out,)
    # print("dz2_db2.shape = ", dz2_db2.shape)
    dloss_db2 = np.sum(dloss_dz2 * dz2_db2, axis=0, keepdims=True) # 1 x d_out
    # print("dloss_db2.shape = ", dloss_db2.shape)

    dz2_da1 = np.transpose(w2) # shape d_out x d_h
    # print("dz2_da1.shape = ", dz2_da1.shape)
    dloss_da1 = np.matmul(dloss_dz2, dz2_da1) # shape batch_size x d_h
    # print("dloss_da1.shape = ", dloss_da1.shape)

    da1_dz1 = a1 * (1 - a1) # shape batch_size x d_h
    # print("da1_dz1.shape = ", da1_dz1.shape)
    dloss_dz1 = dloss_da1 * da1_dz1 # shape batch_size x d_h
    # print("dloss_dz1.shape = ", dloss_dz1.shape)

    dz1_dw1 = np.transpose(a0) # shape d_in x batch_size
    # print("dz1_dw1.shape = ", dz1_dw1.shape)
    dloss_dw1 = np.matmul(dz1_dw1, dloss_dz1) # shape d_in x d_h
    # print("dloss_dw1.shape = ", dloss_dw1.shape)

    dz1_db1 = np.ones(d_h) # shape (d_h,)
    # print("dz1_db1.shape = ", dz1_db1.shape)
    dloss_db1 = np.sum(dloss_dz1 * dz1_db1, axis=0, keepdims=True) # 1 x d_h
    # print("dloss_db1.shape = ", dloss_db1.shape)

    # Update weights and biases
    w1 = w1 - learning_rate * dloss_dw1
    b1 = b1 - learning_rate * dloss_db1
    w2 = w2 - learning_rate * dloss_dw2
    b2 = b2 - learning_rate * dloss_db2
    return w1, b1, w2, b2, loss
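

# Note: the backward pass above uses the gradient of the *summed* categorical
# cross-entropy, which for a softmax output simplifies to
# (predictions - one-hot targets). A sketch of that loss, for comparison with
# the binary cross-entropy reported by learn_once_cross_entropy (the function
# name is illustrative):
def categorical_cross_entropy(predictions, labels_one_hot, eps=1e-12):
    # Negative log-probability of the correct class, summed over the batch
    return -np.sum(labels_one_hot * np.log(np.clip(predictions, eps, 1.0)))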


## QUESTION 13
def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    """Perform num_epoch of training steps of the neural network using the binary cross_entropy loss.
    Args:
        w1: A np.float32 array of shape d_in x d_h, the first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the second layer biases.
        data_train: A np.float32 array of shape batch_size x d_in, the training set.
        labels_train: A np.int64 array of shape batch_size, the labels of the training set.
        learning_rate: The learning rate.
        num_epoch: The number of training epochs.
    Returns:
        w1: A np.float32 array of shape d_in x d_h, the updated first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the updated first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the updated second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the updated second layer biases.
        train_accuracies: A list of the training accuracies across epochs as a list of floats.
        losses: A list of the training losses across epochs as a list of floats.
    """
    #print("data_train.shape = ", data_train.shape)
    #print("labels_train.shape = ", labels_train.shape)
    #print("w1.shape = ", w1.shape)
    #print("b1.shape = ", b1.shape)
    #print("w2.shape = ", w2.shape)
    #print("b2.shape = ", b2.shape)
    train_accuracies = []
    losses = []
    for epoch in range(num_epoch):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate) 
        # Compute accuracy
        # Forward pass
        a0 = data_train # the data are the input of the first layer
        z1 = np.matmul(a0, w1) + b1 # input of the hidden layer
        z1 = np.clip(z1, -500, 500) # avoid overflow in np.exp (sigmoid saturates long before this)
        a1 = 1 / (1 + np.exp(-z1)) # output of the hidden layer (sigmoid activation function)
        z2 = np.matmul(a1, w2) + b2 # input of the output layer
        z2_shifted = z2 - np.max(z2, axis=1, keepdims=True) # shift for numerical stability
        a2 = np.exp(z2_shifted) / np.sum(np.exp(z2_shifted), axis=1, keepdims=True) # output of the output layer (softmax activation function)
        predictions = a2 # shape batch_size x d_out
        y_pred = np.argmax(predictions, axis=1)
        train_accuracy = np.mean(y_pred == labels_train)
        train_accuracies.append(train_accuracy)
        losses.append(loss)
        if epoch % 10 == 0:
            print(f"train_accuracy at epoch {epoch}: {train_accuracy}")

        # print("Epoch %d, loss = %f, train accuracy = %f" % (epoch, loss, train_accuracy))
    return w1, b1, w2, b2, train_accuracies, losses
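

# train_mlp takes one full-batch gradient step per epoch. A mini-batch variant
# usually converges faster on CIFAR-sized data; this is a sketch, not part of
# the assignment (the name train_mlp_minibatch and the batch_size default are
# illustrative).
def train_mlp_minibatch(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch, batch_size=128):
    losses = []
    n = data_train.shape[0]
    for epoch in range(num_epoch):
        perm = np.random.permutation(n) # reshuffle the training set each epoch
        epoch_losses = []
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            w1, b1, w2, b2, loss = learn_once_cross_entropy(
                w1, b1, w2, b2, data_train[idx], labels_train[idx], learning_rate)
            epoch_losses.append(loss)
        losses.append(float(np.mean(epoch_losses)))
    return w1, b1, w2, b2, losses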


## QUESTION 14
def test_mlp(w1, b1, w2, b2, data_test, labels_test):
    """Test the neural network on the given test set.
    Args:
        w1: A np.float32 array of shape d_in x d_h, the first layer weights.
        b1: A np.float32 array of shape 1 x d_h, the first layer biases.
        w2: A np.float32 array of shape d_h x d_out, the second layer weights.
        b2: A np.float32 array of shape 1 x d_out, the second layer biases.
        data_test: A np.float32 array of shape batch_size x d_in, the test set.
        labels_test: A np.int64 array of shape batch_size, the labels of the test set.
    Returns:
        test_accuracy: The accuracy of the network on the given test set.
    """
    #print("data_test.shape = ", data_test.shape)
    #print("labels_test.shape = ", labels_test.shape)
    #print("w1.shape = ", w1.shape)
    #print("b1.shape = ", b1.shape)
    #print("w2.shape = ", w2.shape)
    #print("b2.shape = ", b2.shape)
    # Forward pass
    a0 = data_test # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1 # input of the hidden layer
    z1 = np.clip(z1, -500, 500) # avoid overflow in np.exp (sigmoid saturates long before this)
    a1 = 1 / (1 + np.exp(-z1)) # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2 # input of the output layer
    z2_shifted = z2 - np.max(z2, axis=1, keepdims=True) # shift for numerical stability
    a2 = np.exp(z2_shifted) / np.sum(np.exp(z2_shifted), axis=1, keepdims=True) # softmax output, consistent with training (the argmax is the same either way)
    predictions = a2
    # Compute accuracy
    y_pred = np.argmax(predictions, axis=1)
    test_accuracy = np.mean(y_pred == labels_test)
    return test_accuracy
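

# The same forward pass appears in learn_once_cross_entropy, train_mlp and
# test_mlp. A small helper could centralize it; this is a sketch (the name
# forward is illustrative), equivalent to the inlined versions above.
def forward(w1, b1, w2, b2, data):
    z1 = np.clip(np.matmul(data, w1) + b1, -500, 500) # hidden pre-activation, clipped against overflow
    a1 = 1 / (1 + np.exp(-z1)) # sigmoid hidden activation
    z2 = np.matmul(a1, w2) + b2 # output pre-activation
    z2 = z2 - np.max(z2, axis=1, keepdims=True) # shift for a numerically stable softmax
    a2 = np.exp(z2) / np.sum(np.exp(z2), axis=1, keepdims=True) # softmax output
    return a1, a2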


## QUESTION 15
def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
    """Train an MLP classifier and return the trainig accuracies across epochs as a list of floats and the final testing accuracy as a float.
    Args:
        data_train: A np.float32 array of shape batch_size x d_in, the training set.
        labels_train: A np.int64 array of shape batch_size, the labels of the training set.
        data_test: A np.float32 array of shape batch_size x d_in, the test set.
        labels_test: A np.int64 array of shape batch_size, the labels of the test set.
        d_h: The number of neurons in the hidden layer.
        learning_rate: The learning rate.
        num_epoch: The number of training epochs.
    Returns:
        train_accuracies: A list of the training accuracies across epochs as a list of floats.
        test_accuracy: The accuracy of the network on the given test set.
        losses: A list of the training losses across epochs as a list of floats.
    """
    # Random initialization of the network weights and biases
    d_in = data_train.shape[1]  # input dimension
    d_out = labels_train.max() + 1  # output dimension (number of neurons of the output layer)
    w1 = 2 * np.random.rand(d_in, d_h) - 1  # first layer weights
    b1 = np.zeros((1, d_h))  # first layer biases
    w2 = 2 * np.random.rand(d_h, d_out) - 1  # second layer weights
    b2 = np.zeros((1, d_out))  # second layer biases
    # Train the network
    w1, b1, w2, b2, train_accuracies, losses = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    # Test the network
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)
    return train_accuracies, test_accuracy, losses


if __name__ == "__main__":
    # Define input data
    w1 = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) # d_in = 3, d_h = 2
    b1 = np.array([[0.1, 0.2]]) # d_h = 2
    w2 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5], [0.4, 0.5, 0.6, 0.7, 0.8]]) # d_h = 2, d_out = 5
    b2 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]) # d_out = 5
    data = np.array([[0.1, 0.2, 0.3]]) # batch_size = 1, d_in = 3
    targets = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]) # batch_size = 1, d_out = 5
    learning_rate = 0.1

    # Call function
    w1, b1, w2, b2, cost = learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate)

    # Check output shapes
    assert w1.shape == (3, 2)
    assert b1.shape == (1, 2)
    assert w2.shape == (2, 5)
    assert b2.shape == (1, 5)
    assert cost.shape == ()

    # Test one_hot
    labels = np.array([0, 4, 2, 3])
    print(one_hot(labels))

    # Define input data
    w1 = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) # d_in = 3, d_h = 2
    b1 = np.array([[0.1, 0.2]]) # d_h = 2
    w2 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5], [0.4, 0.5, 0.6, 0.7, 0.8]]) # d_h = 2, d_out = 5
    b2 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]) # d_out = 5
    data = np.array([[0.1, 0.2, 0.3]]) # batch_size = 1, d_in = 3
    labels_train = np.array([4]) # batch_size = 1
    learning_rate = 0.1

    # Test learn_once_cross_entropy
    w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate)
    print(w1, b1, w2, b2, loss)

    # Test train_mlp
    w1, b1, w2, b2, train_accuracies, losses = train_mlp(w1, b1, w2, b2, data, labels_train, learning_rate, 10)
    print(train_accuracies)

    # Test test_mlp
    test_accuracy = test_mlp(w1, b1, w2, b2, data, labels_train)
    print(test_accuracy)

    # Test run_mlp_training
    train_accuracies, test_accuracy, losses = run_mlp_training(data, labels_train, data, labels_train, 2, 0.1, 10)
    print(train_accuracies, test_accuracy)

    import read_cifar as rc
    import matplotlib.pyplot as plt
    data, labels = rc.read_cifar(r"data\cifar-10-batches-py")
    split_factor = 0.9
    data_train, labels_train, data_test, labels_test = rc.split_dataset(data, labels, split_factor)
    d_h = 64
    learning_rate = 0.1
    num_epoch = 100
    train_accuracies, test_accuracy, losses = run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)
    print("ok 3")
    plt.plot(range(num_epoch), train_accuracies)
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.savefig(r"results\mlp.png")
    plt.show()

    plt.plot(range(num_epoch), losses)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.show()