import numpy as np
import pickle
import os
from read_cifar import *
import matplotlib.pyplot as plt


def distance_matrix(M1, M2):
    """Return the pairwise Euclidean distances between the rows of M1 and M2.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so the whole
    matrix is obtained from a single matrix product instead of a Python loop.

    Parameters
    ----------
    M1 : array-like, shape (n1, d)
    M2 : array-like, shape (n2, d)

    Returns
    -------
    numpy.ndarray, shape (n1, n2)
        ``dists[i, j]`` is the Euclidean distance between ``M1[i]`` and
        ``M2[j]``.
    """
    M1 = np.asarray(M1)
    M2 = np.asarray(M2)
    # Integer inputs (e.g. uint8 pixels) would silently overflow in the
    # squaring and the dot product below, so promote them to float64.
    # Floating inputs keep their dtype to avoid doubling memory use.
    if not np.issubdtype(M1.dtype, np.floating):
        M1 = M1.astype(np.float64)
    if not np.issubdtype(M2.dtype, np.floating):
        M2 = M2.astype(np.float64)

    sum_squares_1 = np.sum(M1**2, axis=1, keepdims=True)
    sum_squares_2 = np.sum(M2**2, axis=1, keepdims=True)

    squared = sum_squares_1 - 2 * np.dot(M1, M2.T) + sum_squares_2.T
    # Floating-point cancellation can leave near-zero squared distances
    # slightly negative, which np.sqrt would turn into NaN.
    np.maximum(squared, 0, out=squared)
    return np.sqrt(squared)


def k_smallest_indexes(liste, k):
    """Return the positions of the k smallest values in ``liste``.

    Positions are ordered from the smallest value upward. Equal values keep
    their original relative order because the sort is stable. When ``k`` is
    not in ``1..len(liste)`` an empty list is returned.
    """
    n = len(liste)
    if not 0 < k <= n:
        return []
    ranked = sorted(range(n), key=liste.__getitem__)
    return ranked[:k]


def knn_predict(data_train, labels_train, data_test, k):
    """Predict a label for each row of ``data_test`` by k-nearest-neighbour vote.

    Parameters
    ----------
    data_train : array-like, shape (n_train, d)
    labels_train : sequence of length n_train
    data_test : array-like, shape (n_test, d)
    k : int
        Number of neighbours that vote; must satisfy ``1 <= k <= n_train``.

    Returns
    -------
    list
        One predicted label per test row, in the order of ``data_test``.
        A tie in the vote goes to the tied label held by the closest
        neighbour.

    Raises
    ------
    ValueError
        If ``k`` is outside ``1..n_train``.
    """
    n_train = len(data_train)
    if not 1 <= k <= n_train:
        raise ValueError(f"k must be between 1 and {n_train}, got {k}")

    # Test samples first so that row i holds the distances from test
    # sample i to every training sample (shape: n_test x n_train).
    dists = distance_matrix(data_test, data_train)

    predicted_labels = []
    for distances in dists:
        nearest = k_smallest_indexes(distances, k)
        labels = [labels_train[j] for j in nearest]
        # `labels` is ordered nearest-first and max() returns the first
        # maximal element, so ties favour the closest neighbour's label.
        predicted_labels.append(max(labels, key=labels.count))

    return predicted_labels

def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the fraction of test samples that kNN classifies correctly.

    Runs ``knn_predict`` with the given ``k`` and compares each prediction
    with the entry of ``labels_test`` at the same position. The result is
    the number of matches divided by the number of predictions.
    """
    predictions = knn_predict(data_train, labels_train, data_test, k)
    n_correct = sum(
        1
        for idx, predicted in enumerate(predictions)
        if predicted == labels_test[idx]
    )
    return n_correct / len(predictions)



if __name__ == "__main__":

    K = 50
    split = 0.9
    # Set to True to also plot test accuracy for every k in 1..K_SWEEP_MAX.
    # The sweep is slow: each k recomputes the full kNN evaluation.
    RUN_K_SWEEP = False
    K_SWEEP_MAX = 20

    batch_dir = 'data/cifar-10-batches-py/'
    data, labels = read_cifar(batch_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, K)
    print(accuracy)

    if RUN_K_SWEEP:
        k_values = list(range(1, K_SWEEP_MAX + 1))
        accuracy_vector = [
            evaluate_knn(data_train, labels_train, data_test, labels_test, k)
            for k in k_values
        ]
        plt.plot(k_values, accuracy_vector)
        plt.xlabel("k")
        plt.ylabel("accuracy")
        plt.show()