import numpy as np
import read_cifar
import matplotlib.pyplot as plt


def sigmoid(x):
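    """Element-wise sigmoid activation: 1 / (1 + exp(-x))."""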
    return 1 / (1 + np.exp(-x))

def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
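    """
    Perform one gradient-descent step of the one-hidden-layer MLP
    (sigmoid activations on both layers) using the MSE loss.

    Parameters:
        w1, b1 -- hidden-layer weights and biases, shapes (d_in, d_h) and (1, d_h)
        w2, b2 -- output-layer weights and biases, shapes (d_h, d_out) and (1, d_out)
        data -- input batch of shape (N, d_in)
        targets -- target values of shape (N, d_out)
        learning_rate -- step size of the gradient-descent update

    Returns:
        w1, b1, w2, b2 -- the updated parameters
        loss -- the MSE loss computed before the update
    """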

    N_out = len(data) #number of training examples
    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.dot(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.dot(a1, w2) + b2  # input of the output layer
    a2 = sigmoid(z2)  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets))
    print(f'loss: {loss}')

    # Backpropagation
    delta_a2 = 2 / N_out * (a2 - targets)  # divide by the number of examples to average the error and avoid large gradient jumps
    delta_z2 = delta_a2 * (a2 * (1 - a2))  # derivative of the sigmoid: a2 * (1 - a2)
    delta_w2 = np.dot(a1.T, delta_z2)
    delta_b2 = np.sum(delta_z2, axis=0, keepdims=True)

    delta_a1 = np.dot(delta_z2, w2.T)
    delta_z1 = delta_a1 * (a1 * (1 - a1))
    delta_w1 = np.dot(a0.T, delta_z1)
    delta_b1 = np.sum(delta_z1, axis=0, keepdims=True)

    # Update weights and biases (one gradient-descent step)
    w1 -= learning_rate * delta_w1
    b1 -= learning_rate * delta_b1
    w2 -= learning_rate * delta_w2
    b2 -= learning_rate * delta_b2

    return w1, b1, w2, b2, loss

def one_hot(labels):
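    """
    Convert a vector of integer class labels into a one-hot matrix.
    E.g. labels [1, 0] with 2 classes become [[0, 1], [1, 0]].
    """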
    num_classes = int(np.max(labels) + 1) #num_classes = 10
    one_hot_matrix = np.eye(num_classes)[labels]
    return one_hot_matrix

def softmax_stable(x):
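    """
    Row-wise softmax; the row maximum is subtracted before exponentiation
    for numerical stability (avoids overflow in np.exp).
    """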
    exps = np.exp(x - np.max(x, axis=1, keepdims=True))
    return exps / exps.sum(axis=1, keepdims=True)

def cross_entropy_loss(y_pred, y_true_one_hot):
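    """
    Mean cross-entropy between predicted class probabilities y_pred, shape
    (N, num_classes), and one-hot targets y_true_one_hot; epsilon avoids log(0).
    """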
    epsilon = 1e-10
    loss = - np.sum( y_true_one_hot * np.log(y_pred + epsilon) ) / len(y_pred)
    return loss


def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
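    """
    Perform one gradient-descent step of the MLP (sigmoid hidden layer,
    softmax output layer) using the cross-entropy loss.

    Parameters:
        w1, b1, w2, b2 -- current network parameters
        data -- input batch of shape (N, d_in)
        labels_train -- integer class labels of shape (N,)
        learning_rate -- step size of the gradient-descent update

    Returns:
        w1, b1, w2, b2 -- the updated parameters
        loss -- the cross-entropy loss computed before the update
    """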

    N_out = len(data) #number of training examples

    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax_stable(z2)  # output of the output layer (softmax activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    # Compute loss (cross-entropy loss)
    y_true_one_hot = one_hot(labels_train)
    loss = cross_entropy_loss(predictions, y_true_one_hot)

    # Backpropagation
    delta_z2 = a2 - y_true_one_hot  # gradient of the cross-entropy loss w.r.t. z2 (softmax + cross-entropy)
    delta_w2 = np.dot(a1.T, delta_z2) / N_out  # divide by the batch size to average the gradient and avoid large jumps
    delta_b2 = np.sum(delta_z2, axis=0, keepdims=True) / N_out

    delta_a1 = np.dot(delta_z2, w2.T)
    delta_z1 = delta_a1 * (a1 * (1 - a1))  # derivative of the sigmoid: a1 * (1 - a1)
    delta_w1 = np.dot(a0.T, delta_z1) / N_out
    delta_b1 = np.sum(delta_z1, axis=0, keepdims=True) / N_out

    
    # Update weights and biases 
    w1 -= learning_rate * delta_w1
    b1 -= learning_rate * delta_b1
    w2 -= learning_rate * delta_w2
    b2 -= learning_rate * delta_b2

    return w1, b1, w2, b2, loss


def forward(w1, b1, w2, b2, data):
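    """Run a forward pass and return the softmax class probabilities, shape (N, d_out)."""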
    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = sigmoid(z1)  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax_stable(z2)  # output of the output layer (softmax activation function)
    predictions = a2  # the predicted values are the outputs of the output layer
    return predictions

def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
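    """
    Train the MLP for num_epoch epochs of full-batch gradient descent and
    return the final parameters together with the per-epoch train accuracies.
    """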
    train_accuracies = []
    for epoch in range(num_epoch):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)

        # Compute accuracy
        predictions = forward(w1, b1, w2, b2, data_train)
        predicted_labels = np.argmax(predictions, axis=1)
        accuracy = np.mean(predicted_labels == labels_train)
        train_accuracies.append(accuracy)

        print(f'Epoch {epoch + 1}/{num_epoch}, Loss: {loss:.3f}, Train Accuracy: {accuracy:.5f}')

    return w1, b1, w2, b2, train_accuracies

def test_mlp(w1, b1, w2, b2, data_test, labels_test):
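    """Evaluate the trained MLP on the test set and return its accuracy."""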
 
    # Compute accuracy
    predictions = forward(w1, b1, w2, b2, data_test)
    predicted_labels = np.argmax(predictions, axis=1)
    test_accuracy = np.mean(predicted_labels == labels_test)
    print(f'Test Accuracy: {test_accuracy:.2f}')
    return test_accuracy

def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
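    """
    Initialise a (d_in, d_h, d_out) MLP, train it on the training set and
    evaluate it on the test set.

    Returns:
        train_accuracies -- list of train accuracies, one per epoch
        test_accuracy -- accuracy on the test set after training
    """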

    d_in = data_train.shape[1]
    d_out = 10  # number of classes; could also be computed as len(np.unique(labels_train))

    # Xavier initialisation of the weights (random normal scaled by 1/sqrt(fan_in))
    w1 = np.random.randn(d_in, d_h) / np.sqrt(d_in)
    b1 = np.zeros((1, d_h))
    w2 = np.random.randn(d_h, d_out) / np.sqrt(d_h)
    b2 = np.zeros((1, d_out))

    # Train MLP
    w1, b1, w2, b2, train_accuracies = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)

    # Test MLP
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)
    return train_accuracies, test_accuracy

def plot_graph(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
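    """
    Train and evaluate the MLP, then plot the evolution of the train accuracy
    across epochs and save the figure to image-classification/results/mlp.png.
    """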
    # Run MLP training
    train_accuracies, test_accuracy = run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)
    
    # Plot and save the learning accuracy graph
    plt.figure(figsize=(8, 6))
    epochs = np.arange(1, num_epoch + 1)
    plt.plot(epochs, train_accuracies, marker='x', color='b', label='Train Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('MLP Train Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig('image-classification/results/mlp.png')
    plt.show()



if __name__ == '__main__':
    data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
    X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
    d_in, d_h, d_out = 3072, 64, 10
    learning_rate = 0.1
    num_epoch = 300

    plot_graph(X_train, y_train, X_test ,y_test , d_h, learning_rate, num_epoch)