import numpy as np
import read_cifar as rc
import matplotlib.pyplot as plt



def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
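    """Perform one gradient-descent step on the MSE loss.

    w1, b1, w2, b2 : weights and biases of the two-layer network
    data           : input batch of shape (batch_size, d_in)
    targets        : target values of shape (batch_size, d_out)
    learning_rate  : step size of the gradient update

    Returns the updated weights and biases and the loss value.
    """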
    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = 1 / (1 + np.exp(-z2))  # output of the output layer (sigmoid activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    # Compute loss (MSE)
    loss = np.mean(np.square(predictions - targets))

    # Backward pass
    delta2 = 2 * (predictions - targets) / data.shape[0] * a2 * (1 - a2)  # derivative of the MSE loss w.r.t. z2 (includes the output sigmoid derivative)
    grad_w2 = np.matmul(a1.T, delta2)
    grad_b2 = np.sum(delta2, axis=0, keepdims=True)
    delta1 = np.matmul(delta2, w2.T) * a1*(1-a1)
    grad_w1 = np.matmul(a0.T, delta1)
    grad_b1 = np.sum(delta1, axis=0, keepdims=True)

    w1 -= learning_rate * grad_w1
    b1 -= learning_rate * grad_b1
    w2 -= learning_rate * grad_w2
    b2 -= learning_rate * grad_b2

    return w1, b1, w2, b2, loss

def one_hot(labels):
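    """Convert a vector of integer labels into a one-hot encoded matrix."""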
    n_classes = np.max(labels) + 1
    return np.eye(n_classes)[labels]

def softmax(x):
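    """Row-wise softmax of a 2-D array of scores."""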
    exps = np.exp(x - np.max(x, axis=1, keepdims=True))  # subtract the row-wise max for numerical stability
    return exps / exps.sum(axis=1, keepdims=True)


def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
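    """Perform one gradient-descent step on the cross-entropy loss.

    Same parameters as learn_once_mse, except that labels_train holds
    integer class labels (they are one-hot encoded internally).
    Returns the updated weights and biases and the loss value.
    """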
    # Forward pass
    a0 = data # the data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax(z2)  # output of the output layer (softmax activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    labels_one_hot = one_hot(labels_train)

    # Compute loss (cross-entropy between the softmax outputs and the one-hot targets)
    loss = -np.mean(np.sum(labels_one_hot * np.log(predictions + 1e-12), axis=1))  # small epsilon avoids log(0)

    # Backward pass (softmax + cross-entropy gives a simple output-layer gradient)
    delta2 = (a2 - labels_one_hot) / data.shape[0]
    grad_w2 = np.matmul(a1.T, delta2)
    grad_b2 = np.sum(delta2, axis=0, keepdims=True)
    delta1 = np.matmul(delta2, w2.T) * a1 * (1 - a1)
    grad_w1 = np.matmul(a0.T, delta1)
    grad_b1 = np.sum(delta1, axis=0, keepdims=True)

    w1 -= learning_rate * grad_w1
    b1 -= learning_rate * grad_b1
    w2 -= learning_rate * grad_w2
    b2 -= learning_rate * grad_b2

    return w1, b1, w2, b2, loss

def accuracy(w1, b1, w2, b2, data, labels):
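    """Compute the classification accuracy of the network on (data, labels)."""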
    # Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = 1 / (1 + np.exp(-z1))
    z2 = np.matmul(a1, w2) + b2
    a2 = softmax(z2)
    predictions = a2
    return np.mean(np.argmax(predictions, axis=1) == labels)


def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
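    """Train the network for num_epoch epochs of full-batch gradient descent.

    Returns the trained weights and biases and the list of training
    accuracies measured after each epoch.
    """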
    train_accuracies = []

    for i in range(num_epoch):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)
        train_accuracies.append(accuracy(w1, b1, w2, b2, data_train, labels_train))

    return w1, b1, w2, b2, train_accuracies

def test_mlp(w1, b1, w2, b2, data_test, labels_test):
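    """Run a forward pass on the test set and return the test accuracy."""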
    # Forward pass
    a0 = data_test  # the test data are the input of the first layer
    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation function)
    z2 = np.matmul(a1, w2) + b2  # input of the output layer
    a2 = softmax(z2)  # output of the output layer (softmax activation function)
    predictions = a2  # the predicted values are the outputs of the output layer

    test_accuracy = np.mean(np.argmax(predictions, axis=1) == labels_test)

    return test_accuracy

def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
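    """Initialize a two-layer MLP with d_h hidden units, train it, and evaluate it.

    Returns the per-epoch training accuracies and the final test accuracy.
    """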
    d_in = data_train.shape[1]
    d_out = np.max(labels_train) + 1

    # Random initialization of the weights in [-1, 1), biases initialized to zero
    w1 = 2 * np.random.rand(d_in, d_h) - 1
    b1 = np.zeros((1, d_h))
    w2 = 2 * np.random.rand(d_h, d_out) - 1
    b2 = np.zeros((1, d_out))

    w1, b1, w2, b2, train_accuracies = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)

    return train_accuracies, test_accuracy

if __name__ == "__main__":
    directory = "data/cifar-10-batches-py/"
    split_factor = 0.9
    
    d_h = 64
    learning_rate = 0.1
    num_epoch = 100
    
    
    data, labels = rc.read_cifar(directory)
    data_train, labels_train, data_test, labels_test = rc.split_dataset(data, labels, split_factor)

    train_accuracies, test_accuracy = run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)
    print(f"Final test accuracy: {test_accuracy:.3f}")

    plt.plot(train_accuracies)
    plt.xlabel('Epoch')
    plt.ylabel('Training Accuracy')
    plt.title('Training Accuracy Over Epochs')
    plt.savefig('results/mlp.png')
    plt.show()