import numpy as np
from read_cifar import read_cifar, split_dataset
import matplotlib.pyplot as plt

def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray):
    """Return the pairwise Euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so the whole
    matrix is obtained with one matrix product instead of explicit loops.

    Args:
        matrix_a: 2-D array of shape (n_a, d), one sample per row.
        matrix_b: 2-D array of shape (n_b, d), one sample per row.

    Returns:
        Array of shape (n_a, n_b) whose entry (i, j) is the Euclidean
        distance between ``matrix_a[i]`` and ``matrix_b[j]``.
    """
    matrix_a = np.asarray(matrix_a)
    matrix_b = np.asarray(matrix_b)
    # Integer inputs (e.g. uint8 pixel values) would silently overflow when
    # squared, so promote them to float before doing any arithmetic.
    if not np.issubdtype(matrix_a.dtype, np.floating):
        matrix_a = matrix_a.astype(np.float64)
    if not np.issubdtype(matrix_b.dtype, np.floating):
        matrix_b = matrix_b.astype(np.float64)

    sum_squares_a = np.sum(matrix_a**2, axis=1, keepdims=True)  # (n_a, 1)
    sum_squares_b = np.sum(matrix_b**2, axis=1, keepdims=True)  # (n_b, 1)
    dot_product = matrix_a @ matrix_b.T  # (n_a, n_b)

    squared = sum_squares_a - 2 * dot_product + sum_squares_b.T
    # Cancellation in the expansion can yield tiny negative values for
    # (near-)identical rows; clip them so sqrt does not produce NaN.
    np.maximum(squared, 0, out=squared)
    return np.sqrt(squared)

def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int):
    """Predict a label for each test sample by majority vote of its k nearest neighbours.

    Args:
        dists: 2-D array of shape (n_test, n_train); ``dists[i, j]`` is the
            distance between test sample i and training sample j.
        labels_train: labels of the training samples, indexed like the
            columns of ``dists``.
        k: number of neighbours taking part in the vote (must be >= 1).
            If k exceeds the number of training samples, all of them vote.

    Returns:
        Float array of shape (n_test,) with the predicted label of each test
        sample. When several classes tie for the most votes, the smallest
        label wins (``np.unique`` returns sorted classes and ``np.argmax``
        picks the first maximum).

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    labels_train = np.asarray(labels_train)
    n_test = np.size(dists, 0)
    labels_predicts = np.zeros(n_test)
    for i in range(n_test):
        # Indices of the k smallest distances, i.e. the k nearest neighbours.
        k_neighbors_index = np.argsort(dists[i, :], kind="stable")[:k]
        labels_k_neighbors = labels_train[k_neighbors_index]
        # Count how often each class appears among the k neighbours.
        classes, counts = np.unique(labels_k_neighbors, return_counts=True)
        # Predict the most frequent class; index into `classes`, since
        # `counts` is aligned with the unique classes, not with the neighbours.
        labels_predicts[i] = classes[np.argmax(counts)]
    return labels_predicts

def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray, data_test: np.ndarray, labels_test: np.ndarray, k: int):
    """Train-free k-NN evaluation: return the classification accuracy on a test set.

    Args:
        data_train: training samples, one per row.
        labels_train: labels of the training samples.
        data_test: test samples, one per row.
        labels_test: true labels of the test samples.
        k: number of neighbours used by the majority vote.

    Returns:
        Fraction (float in [0, 1]) of test samples whose predicted label
        equals the true label.

    Raises:
        ValueError: if the test set is empty or if ``labels_test`` does not
            have one label per test sample.
    """
    labels_test = np.asarray(labels_test)
    n_test = np.size(data_test, 0)
    if n_test == 0:
        raise ValueError("data_test must contain at least one sample")
    if np.size(labels_test, 0) != n_test:
        raise ValueError(
            f"labels_test has {np.size(labels_test, 0)} entries but data_test has {n_test} samples"
        )
    dists = distance_matrix(data_test, data_train)
    labels_predicts = knn_predict(dists, labels_train, k)
    # Predictions are stored as floats, so compare with a small tolerance
    # rather than exact equality.
    correct = np.abs(labels_predicts - labels_test) < 10**(-7)
    return float(np.mean(correct))

def plot_knn(data_train: np.ndarray, labels_train: np.ndarray, data_test: np.ndarray, labels_test: np.ndarray, n: int):
    """Plot the k-NN test accuracy as a function of k, for k = 1 .. n.

    Args:
        data_train: training samples, one per row.
        labels_train: labels of the training samples.
        data_test: test samples, one per row.
        labels_test: true labels of the test samples.
        n: largest number of neighbours to evaluate.

    Returns:
        Array of shape (n,) whose entry k-1 is the accuracy obtained with k
        neighbours (also shown in a matplotlib figure).
    """
    labels_test = np.asarray(labels_test)
    # The distance matrix does not depend on k, so compute it only once
    # instead of once per value of k.
    dists = distance_matrix(data_test, data_train)
    k_values = np.arange(1, n + 1)
    accuracy_vector = np.zeros(n)
    for index, k in enumerate(k_values):
        labels_predicts = knn_predict(dists, labels_train, int(k))
        # Same tolerance-based comparison as evaluate_knn.
        accuracy_vector[index] = np.mean(np.abs(labels_predicts - labels_test) < 10**(-7))
    # Plot against the actual k values (1..n), not the array indices (0..n-1).
    plt.plot(k_values, accuracy_vector)
    plt.xlabel("k (number of neighbours)")
    plt.ylabel("accuracy")
    plt.show()
    return accuracy_vector





if __name__ == "__main__":
    data, labels = read_cifar()
    # 80% of the samples go to the training set, the rest to the test set
    # (split ratio as passed to split_dataset).
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)
    k = 5  # Number of neighbours
    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
    # Report the result; previously it was computed and silently discarded.
    print(f"k-NN accuracy (k={k}): {accuracy:.4f}")