import numpy as np
import pickle
import matplotlib.pyplot as plt
from read_cifar import read_cifar_batch, split_dataset

def distance_matrix(data_train, data_test):
    """Return the matrix of squared L2 distances between test and train images.

    dist[i, j] is the squared Euclidean distance between test image i and train
    image j; the arrays are assumed to be float (uint8 would wrap on subtraction).
    """
    dist_mat = []
    for image_test in data_test:
        dist_mat.append([])
        for image_train in data_train:
            dist_mat[-1].append(np.sum(np.square(image_train - image_test)))
    return np.array(dist_mat)
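
# A possible vectorized alternative (a sketch, not used by the functions below):
# since ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, the whole matrix can be built
# from a single matrix product instead of two Python loops.
def distance_matrix_vectorized(data_train, data_test):
    """Same result as distance_matrix, computed without explicit Python loops."""
    train = data_train.astype(np.float64)
    test = data_test.astype(np.float64)
    train_sq = np.sum(train ** 2, axis=1)      # shape (n_train,)
    test_sq = np.sum(test ** 2, axis=1)        # shape (n_test,)
    cross = test @ train.T                     # shape (n_test, n_train)
    return test_sq[:, None] + train_sq[None, :] - 2.0 * cross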

def knn_predict(dist, labels_train, k):
    """Predict a label for each test image by majority vote among its k nearest neighbors."""
    resultat = []
    for image_test in dist:
        # Indices of the k smallest distances (unordered, which is enough for voting).
        k_max = np.argpartition(image_test, k)[:k]
        # Majority vote among the labels of these k neighbors.
        val, count = np.unique(labels_train[k_max], return_counts=True)
        indexe = np.argmax(count)
        resultat.append(val[indexe])
    return resultat

def knn_predict2(dist, labels_train, k):
    """Variant of knn_predict that breaks ties between equally frequent labels.

    For each test image, count the votes of the k nearest neighbors per label
    and accumulate their summed distance; if several labels receive the same
    number of votes, the one with the smallest summed distance wins.
    """
    resultat = []
    for im in dist:
        dico = {}
        kmax = np.argpartition(im, k)[:k]
        for indexe in kmax:
            if labels_train[indexe] in dico:
                dico[labels_train[indexe]][0] += 1
                dico[labels_train[indexe]][1] += im[indexe]
            else:
                dico[labels_train[indexe]] = [1, im[indexe]]
        # Sort by vote count (descending) and keep only the labels with the most votes.
        dico = sorted(dico.items(), key=lambda item: item[1][0], reverse=True)
        max_value = dico[0][1][0]
        dico = [item for item in dico if item[1][0] == max_value]
        if len(dico) > 1:
            # Tie: prefer the label whose neighbors are closest overall.
            dico = sorted(dico, key=lambda item: item[1][1])
        resultat.append(dico[0][0])
    return resultat
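
# Minimal sanity check (a sketch, not called anywhere in this script): on this
# hand-made two-class example both predictors return class 0, since the three
# nearest neighbors of the single test image all belong to class 0.
def _toy_knn_check(k=3):
    labels_train = np.array([0, 0, 0, 1, 1, 1])
    dist = np.array([[1.0, 2.0, 3.0, 10.0, 11.0, 12.0]])  # one test image
    return knn_predict(dist, labels_train, k), knn_predict2(dist, labels_train, k)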

def affichage(d_train, l_train, d_test, l_test):
    """Display train and test images on a 5x4 grid with their class names.

    The grid has 20 cells, so len(l_train) + len(l_test) must not exceed 20.
    """
    long, large = 5, 4

    # Class index -> human-readable name mapping shipped with CIFAR-10.
    with open("data/cifar-10-batches-py/batches.meta", 'rb') as file:
        batch_data = pickle.load(file, encoding='bytes')
    liste_names = np.array(batch_data[b'label_names'])
    fig, axes = plt.subplots(large, long, figsize=(12, 5))
    fig.subplots_adjust(hspace=0.5)
    for i, _ in enumerate(l_train):
        # Rebuild a 32x32x3 image from the flattened CIFAR row (channel-major layout).
        im = np.array(np.reshape(d_train[i, 0:3072], (32, 32, 3), order='F'), dtype=np.int64)
        im = np.transpose(im, (1, 0, 2))
        name = liste_names[l_train[i]]
        axes[i // long, i % long].imshow(im)
        axes[i // long, i % long].set_title(f"Train : {name.decode('utf-8')}")
        axes[i // long, i % long].axis('off')
    for i, _ in enumerate(l_test):
        im = np.array(np.reshape(d_test[i, 0:3072], (32, 32, 3), order='F'), dtype=np.int64)
        im = np.transpose(im, (1, 0, 2))
        # Test images fill the grid cells after the training ones.
        j = i + len(l_train)
        name = liste_names[l_test[i]]
        axes[j // long, j % long].imshow(im)
        axes[j // long, j % long].set_title(f"Test : {name.decode('utf-8')}")
        axes[j // long, j % long].axis('off')
    plt.show()
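
# Example usage of affichage (a sketch, assuming the CIFAR batch files are on
# disk): the 5x4 grid holds 20 images, e.g. 10 train and 10 test images.
#   d, l = read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
#   d_train, l_train, d_test, l_test = split_dataset(d, l, 0.9)
#   affichage(d_train[:10], l_train[:10], d_test[:10], l_test[:10])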

def application(cifar_path, nbr_val, split_coef, nbr_knn):
    """Load a CIFAR batch, keep the first nbr_val images, split, and return the k-NN accuracy."""
    d, l = read_cifar_batch(cifar_path)
    d_train, l_train, d_test, l_test = split_dataset(d[:nbr_val, :], l[:nbr_val], split_coef)
    dist_matrice = distance_matrix(d_train, d_test)
    res = knn_predict(dist_matrice, l_train, nbr_knn)
    # Fraction of test images whose predicted label matches the true one.
    stat = 0
    for i in range(len(l_test)):
        if l_test[i] == res[i]:
            stat += 1
    return stat / len(l_test)

def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the classification accuracy of the k-NN classifier on the test set."""
    dist_matrice = distance_matrix(data_train, data_test)
    res = knn_predict(dist_matrice, labels_train, k)
    return np.sum(labels_test == res) / len(labels_test)
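
# Quick self-contained check of evaluate_knn (a sketch; the synthetic data and
# the helper name are illustrative, not part of the original script): two well
# separated Gaussian blobs should be classified almost perfectly with k = 3.
def _evaluate_knn_on_synthetic_data(k=3, seed=0):
    rng = np.random.default_rng(seed)
    blob0 = rng.normal(loc=0.0, scale=1.0, size=(50, 8))   # class 0
    blob1 = rng.normal(loc=10.0, scale=1.0, size=(50, 8))  # class 1
    data = np.vstack([blob0, blob1])
    labels = np.array([0] * 50 + [1] * 50)
    order = rng.permutation(len(labels))                   # random train/test split
    train_idx, test_idx = order[:80], order[80:]
    return evaluate_knn(data[train_idx], labels[train_idx],
                        data[test_idx], labels[test_idx], k)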


if __name__ == "__main__":
    nbr_knn = 20
    nbr_val = 10
    x = range(1, nbr_knn)
    d, l = read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
    for essai in range(nbr_val):
        d_train, l_train, d_test, l_test = split_dataset(d, l, 0.9)
        dist_matrice = distance_matrix(d_train, d_test)
        y1 = []
        #y2 = []
        for knn in x:
            res = knn_predict(dist_matrice, l_train, knn)
            #res2 = knn_predict2(dist_matrice, l_train, knn)
            y1.append(np.sum(l_test == res) / len(l_test))
            #y2.append(np.sum(l_test == res2) / len(l_test))
        plt.plot(x, y1, label=f'k-NN accuracy, run {essai}')
        #plt.plot(x, y2, label='k-NN accuracy (knn_predict2)')
    plt.xlabel('number of nearest neighbors considered')
    plt.ylabel('correct classification rate')
    plt.xticks(range(1, nbr_knn))  # Show integer values on the x-axis
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.legend()
    plt.show()