import read_cifar
import numpy as np
import matplotlib.pyplot as plt

def _as_float_array(matrix):
    """Return *matrix* as an ndarray with a floating dtype (floats are kept as-is)."""
    arr = np.asarray(matrix)
    if np.issubdtype(arr.dtype, np.floating):
        return arr
    # Integer inputs (e.g. uint8 pixels) would overflow when squared.
    return arr.astype(np.float64)


def distance_matrix(matrix1, matrix2):
    """Compute the pairwise Euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b so the whole
    computation is a single matrix product.

    Args:
        matrix1: 2-D array-like of shape (n1, d); callers in this file pass
            the test samples here.
        matrix2: 2-D array-like of shape (n2, d); callers in this file pass
            the training samples here.

    Returns:
        Array of shape (n1, n2) where entry (i, j) is the Euclidean distance
        between row i of ``matrix1`` and row j of ``matrix2``.
    """
    a = _as_float_array(matrix1)
    b = _as_float_array(matrix2)

    sq_norms_a = np.sum(np.square(a), axis=1, keepdims=True)  # (n1, 1)
    sq_norms_b = np.sum(np.square(b), axis=1, keepdims=True)  # (n2, 1)
    cross = np.dot(a, b.T)                                    # (n1, n2)

    sq_dists = sq_norms_a + sq_norms_b.T - 2 * cross
    # Rounding can make near-zero squared distances slightly negative, which
    # would turn into NaN under the square root; clamp them to zero first.
    np.maximum(sq_dists, 0, out=sq_dists)
    return np.sqrt(sq_dists, out=sq_dists)

def knn_predict(dists, labels_train, k):
    output = []
    # Loop on all the images_test
    for i in range(len(dists)):
        # Innitialize table to store the neighbors
        res = [0] * 10
        # Get the closest neighbors
        labels_close = np.argsort(dists[i])[:k]
        for label in labels_close:
            #add a label to the table of result
            res[labels_train[label]] += 1
        # Get the class with the maximum neighbors
        label_temp = np.argmax(res) #Careful to the logic here, if there is two or more maximum, the function the first maximum encountered
        output.append(label_temp)
    return(np.array(output))

def evaluate_knn(data_train, labels_train, data_test, labels_tests, k):
    """Classify the test samples with k-NN and return the accuracy.

    Args:
        data_train: Training samples, one per row.
        labels_train: Class labels of the training samples.
        data_test: Test samples, one per row.
        labels_tests: Ground-truth labels of the test samples; must expose
            ``.shape`` (e.g. a NumPy array).
        k: Number of neighbours used for the vote.

    Returns:
        Number of matching predictions divided by ``labels_tests.shape[0]``.
    """
    # Test samples go first so that each row of the distance matrix
    # corresponds to one test sample.
    predictions = knn_predict(
        distance_matrix(data_test, data_train), labels_train, k
    )

    n_samples = labels_tests.shape[0]
    n_correct = (labels_tests == predictions).sum()
    return n_correct / n_samples

def bench_knn():
    """Plot k-NN accuracy on CIFAR-10 for odd k from 1 to 19.

    Loads the dataset through ``read_cifar``, splits it with a 0.9 split
    factor (presumably the training fraction -- confirm against
    ``read_cifar.split_dataset``), evaluates the classifier for each k, then
    saves the accuracy curve to disk and displays it.
    """
    k_indices = list(range(1, 20, 2))
    accuracies = []

    # Load data
    data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
    X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
    # To benchmark on a single batch instead:
    # data, labels = read_cifar.read_cifar_batch('image-classification/data/cifar-10-batches-py/data_batch_1')
    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)

    # The distance matrix does not depend on k, so compute it once instead
    # of once per value of k.
    dists = distance_matrix(X_test, X_train)
    n_test = y_test.shape[0]
    for k in k_indices:
        predictions = knn_predict(dists, y_train, k)
        accuracies.append((y_test == predictions).sum() / n_test)

    # Save before showing: with interactive backends the figure is torn down
    # when the window opened by show() is closed, so saving afterwards would
    # write an empty image.
    fig = plt.figure()
    plt.plot(k_indices, accuracies)
    plt.title("Accuracy as function of k")
    fig.savefig('image-classification/results/knn_batch_1.png')
    plt.show()
    plt.close(fig)


# Script entry point: run the k-NN benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":

    bench_knn()

    # --- Disabled scratch code kept for manual checks ---
    # Single evaluation on the full dataset with k = 5:
    # data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
    # print(evaluate_knn(X_train, y_train, X_test, y_test, 5))
    # print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

    # Tiny hand-made example for checking distance_matrix:
    # y_test = []
    # x_test = np.array([[1,2],[4,6]])
    # x_train = np.array([[2,4],[7,2],[4,6]])
    # y_train = [1,2,1]
    # dist = distance_matrix(x_test,x_train)