import numpy as np
from read_cifar import read_cifar
from read_cifar import split_dataset
import matplotlib.pyplot as plt
def distance_matrix(a, b):
    """Return the pairwise Euclidean distances between the rows of ``b`` and ``a``.

    Uses the expansion ||b_i - a_j||^2 = ||b_i||^2 + ||a_j||^2 - 2 b_i . a_j
    so the whole matrix is built from one matrix product instead of a
    Python-level double loop.

    Args:
        a: 2-D array of shape (n_a, d) (called with the training data in
            this script).
        b: 2-D array of shape (n_b, d) (called with the test data in this
            script).

    Returns:
        Array of shape (n_b, n_a); entry [i, j] is the distance between
        ``b[i]`` and ``a[j]``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Integer inputs (e.g. uint8 pixels) would silently overflow when squared,
    # so promote them to float. Floating inputs keep their dtype to avoid
    # doubling the memory of very large matrices.
    if not np.issubdtype(a.dtype, np.floating):
        a = a.astype(np.float64)
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype(np.float64)

    a_squared_sum = np.sum(a * a, axis=1)  # shape (n_a,)
    b_squared_sum = np.sum(b * b, axis=1)  # shape (n_b,)

    # Build the squared distances in place to limit temporaries:
    # -2 b.a^T + ||b||^2 (broadcast over columns) + ||a||^2 (over rows).
    squared = b @ a.T
    squared *= -2.0
    squared += b_squared_sum[:, np.newaxis]
    squared += a_squared_sum[np.newaxis, :]
    # Cancellation can leave near-zero entries slightly negative; clip so
    # sqrt never produces NaN.
    np.maximum(squared, 0, out=squared)
    return np.sqrt(squared)


def k_min_args(list, k):
    """Return the indices of the ``k`` smallest entries of a 1-D sequence.

    The returned indices are ordered by increasing value (smallest first).
    ``k`` larger than the number of entries is clipped to that number, and
    ``k <= 0`` yields an empty array.

    Args:
        list: 1-D array-like of comparable values. The parameter name is kept
            for backward compatibility even though it shadows the builtin.
        k: number of indices to return.

    Returns:
        1-D integer numpy array with at most ``k`` indices.
    """
    values = np.asarray(list)
    k = min(k, values.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    # Only the k-th smallest element (0-based position k-1) must be placed;
    # using kth=k would raise when k equals the input length.
    candidates = np.argpartition(values, k - 1)[:k]
    # argpartition leaves the k candidates unordered; sort just those
    # (O(k log k)) so callers get them smallest first.
    return candidates[np.argsort(values[candidates], kind="stable")]

    
def knn_predict(dists, labels_train, k):
    """Predict one label per test sample by majority vote of its k nearest
    training samples.

    Args:
        dists: 2-D array-like of shape (n_test, n_train); row i holds the
            distances from test sample i to every training sample (the
            layout returned by ``distance_matrix(data_train, data_test)``).
        labels_train: indexable sequence of the n_train training labels.
        k: number of neighbours that vote.

    Returns:
        List of n_test predicted labels. When several labels receive the same
        number of votes, the one belonging to the closest neighbour wins.
    """
    predicted_labels = []
    for test_dists in np.asarray(dists):
        nearest_neighbors = np.asarray(k_min_args(test_dists, k))
        # Visit neighbours nearest first, so on a tie the label that was
        # inserted first (i.e. the closest one) is chosen below.
        nearest_neighbors = nearest_neighbors[
            np.argsort(test_dists[nearest_neighbors], kind="stable")
        ]

        # Count the labels of the neighbours.
        neighbors_labels = {}
        for neighbor in nearest_neighbors:
            label = labels_train[neighbor]
            neighbors_labels[label] = neighbors_labels.get(label, 0) + 1

        # max() returns the first maximal key in insertion order (nearest
        # first), giving a real majority vote with deterministic tie-breaking.
        predicted_labels.append(max(neighbors_labels, key=neighbors_labels.get))
    return predicted_labels

def evaluate_knn(data_train, labels_train, data_test, label_test, k):
    """Return the k-NN classification accuracy on a held-out test set.

    Args:
        data_train: training samples, one per row.
        labels_train: labels of the training samples.
        data_test: test samples, one per row.
        label_test: true labels of the test samples.
        k: number of neighbours used for the vote.

    Returns:
        Fraction (0.0-1.0) of test samples whose predicted label equals the
        true label.

    Raises:
        ValueError: if ``label_test`` is empty or its length differs from the
            number of test samples.
    """
    # Validate before the (expensive) distance computation.
    if len(label_test) == 0:
        raise ValueError("evaluate_knn needs at least one test label")
    if len(data_test) != len(label_test):
        raise ValueError(
            f"data_test has {len(data_test)} samples but label_test has "
            f"{len(label_test)} labels"
        )

    dists = distance_matrix(data_train, data_test)
    predicted_labels = knn_predict(dists, labels_train=labels_train, k=k)
    correct_counter = sum(
        1
        for predicted, expected in zip(predicted_labels, label_test)
        if predicted == expected
    )
    return correct_counter / len(label_test)
    

if __name__ == "__main__":
    data, labels = read_cifar("data/cifar-10-batches-py/")
    split_factor = 0.9
    data_train, labels_train, data_test, labels_test = split_dataset(
        data, labels, split_factor
    )

    ks = range(1, 20)
    accuracies = []
    # The distance matrix does not depend on k, so compute it once instead
    # of once per k (as calling evaluate_knn in the loop would).
    dists = distance_matrix(data_train, data_test)
    for k in ks:
        predicted = knn_predict(dists, labels_train, k)
        correct = sum(
            1 for p, t in zip(predicted, labels_test) if p == t
        )
        accuracy = correct / len(labels_test)
        accuracies.append(accuracy)
        print(f"the k-nearest neighbors at k = {k} gives an accuracy of {accuracy}")

    # Plot against the actual k values, not the list index.
    plt.plot(ks, accuracies, marker="o")
    plt.xlabel("k")
    plt.ylabel("accuracy")
    plt.title("k-NN accuracy")
    plt.show()

