# -*- coding: utf-8 -*-
"""
Created on Fri Oct 20 17:39:37 2023

@author: oscar
"""
import read_cifar
import numpy as np
import statistics
from statistics import mode
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

def distance_matrix(A, B):
    """Return the pairwise Euclidean distances between the rows of A and B.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b`` so the whole
    matrix is obtained from a single matrix product.

    Parameters
    ----------
    A : array_like, shape (n_a, d)
    B : array_like, shape (n_b, d)

    Returns
    -------
    numpy.ndarray, shape (n_a, n_b)
        ``dists[i, j]`` is the Euclidean distance between ``A[i]`` and
        ``B[j]``.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Integer inputs (e.g. uint8 pixel values) would wrap around when squared
    # or multiplied, so promote them to floating point first.
    if not np.issubdtype(A.dtype, np.floating):
        A = A.astype(np.float64)
    if not np.issubdtype(B.dtype, np.floating):
        B = B.astype(np.float64)

    # (n_a, 1) column and (n_b,) row broadcast to the full (n_a, n_b) grid.
    sq_norms_a = np.sum(A * A, axis=1, keepdims=True)
    sq_norms_b = np.sum(B * B, axis=1)
    squared = sq_norms_a + sq_norms_b - 2 * np.dot(A, B.T)

    # Rounding can leave tiny negative values where the true distance is ~0;
    # clamp them so the square root never produces NaN.
    np.maximum(squared, 0, out=squared)
    return np.sqrt(squared)

def knn_predict(dists, labels_train, k):
    """Predict a label for every test sample by majority vote of its k
    nearest training samples.

    Parameters
    ----------
    dists : numpy.ndarray, shape (number_train, number_test)
        ``dists[i, j]`` is the distance between training sample ``i`` and
        test sample ``j``.
    labels_train : array_like, shape (number_train,)
        Label of each training sample.
    k : int
        Number of neighbours that vote; must be at least 1. Values larger
        than ``number_train`` use every training sample.

    Returns
    -------
    numpy.ndarray, shape (number_test,)
        Predicted labels, stored in a float array.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.
    """
    # A negative k would silently slice "all but the last |k|" neighbours and
    # k == 0 would fail deep inside ``mode``; reject both up front.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    number_train, number_test = np.shape(dists)
    labels_train = np.asarray(labels_train)

    # initialize the predicted labels to zeros
    labels_predicted = np.zeros(number_test)

    for j in range(number_test):
        # Stable sort: equidistant neighbours are ranked by training index,
        # which makes the vote deterministic. Sorting column by column keeps
        # the index buffer at number_train entries.
        knn_indices = np.argsort(dists[:, j], kind="stable")[:k]
        labels_predicted[j] = mode(labels_train[knn_indices])
    return labels_predicted

def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Classify the test set with k-NN and return the fraction classified
    correctly.

    The train-to-test distance matrix is built with ``distance_matrix``,
    labels are predicted with ``knn_predict``, and the share of predictions
    equal to ``labels_test`` is printed and returned.
    """
    predictions = knn_predict(
        distance_matrix(data_train, data_test), labels_train, k
    )
    hits = np.sum(labels_test == predictions)
    classification_rate = hits / len(labels_test)
    print(classification_rate)
    return classification_rate
    
if __name__ == "__main__":
    # Local import: only the script entry point needs it.
    import os

    file = "./data/cifar-10-python/"
    data, labels = read_cifar.read_cifar(file)
    data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, 0.9)

    # Evaluate k = 1 .. k_max - 1. The bound has its own name so the loop
    # variable no longer overwrites it.
    k_max = 8
    k_values = list(range(1, k_max))
    evaluations = [
        evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        for k in tqdm(k_values)
    ]

    fig = plt.figure()
    plt.title("Prediction accuracy as a function of k")
    plt.xlabel("k-nearest neighbors")
    plt.ylabel("Accuracy (%)")
    # Plot against the actual k values (not list indices) and convert the
    # fractional rates to percent so they match the axis label.
    plt.plot(k_values, [100 * rate for rate in evaluations])

    # Save before showing: once the window opened by show() is closed, the
    # figure may be discarded and a later savefig would write an empty image.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/knn.png')
    plt.show()