from read_cifar import *
from collections import Counter
import matplotlib.pyplot as plt

# Compute the Euclidean distance matrix: rows index the training samples, columns index the test samples.
# dists[i][j] holds the Euclidean distance between the i-th training image and the j-th test image.
def distance_matrix(data_train, data_test):
    """Return the pairwise Euclidean distances between two sets of samples.

    The matrix is computed with the expansion
    ``||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2`` so that a single matrix
    product replaces an explicit double loop.

    Args:
        data_train: 2-D array-like, one training sample per row.
        data_test: 2-D array-like, one test sample per row, with the same
            number of columns as ``data_train``.

    Returns:
        Array of shape (n_train, n_test) whose entry [i, j] is the distance
        between training sample i and test sample j.
    """
    data_train = np.asarray(data_train)
    data_test = np.asarray(data_test)

    # Integer inputs (e.g. raw uint8 pixels) would silently wrap around when
    # squared, so promote anything that is not already floating point.
    if not np.issubdtype(data_train.dtype, np.floating):
        data_train = data_train.astype(np.float64)
    if not np.issubdtype(data_test.dtype, np.floating):
        data_test = data_test.astype(np.float64)

    train_squared = np.sum(data_train ** 2, axis=1, keepdims=True)
    test_squared = np.sum(data_test ** 2, axis=1, keepdims=True)
    dot_product = np.dot(data_train, data_test.T)
    squared_dists = train_squared - 2 * dot_product + test_squared.T

    # Rounding errors can push tiny squared distances slightly below zero,
    # which would turn into NaN under the square root; clamp them to zero.
    np.maximum(squared_dists, 0, out=squared_dists)

    return np.sqrt(squared_dists)

def knn_predict(dists, labels_train, k):
    """Predict a label for every test sample by majority vote of its k nearest neighbours.

    Args:
        dists: 2-D array of distances with one row per training sample and
            one column per test sample.
        labels_train: 1-D array-like with the label of each training sample
            (same order as the rows of ``dists``).
        k: Number of neighbours that vote; must satisfy
            ``1 <= k <= number of training samples``.

    Returns:
        A list with one predicted label per test sample. When several labels
        receive the same number of votes, the one whose closest voter is
        nearest to the test sample wins.

    Raises:
        ValueError: If ``k`` is outside the valid range.
    """
    dists = np.asarray(dists)
    labels_train = np.asarray(labels_train)

    n_train = dists.shape[0]
    if not 1 <= k <= n_train:
        raise ValueError(
            'k must be between 1 and the number of training samples ('
            + str(n_train) + '), got ' + str(k)
        )

    predictions = []

    # Transpose so that each iteration sees the distances from one test
    # sample to every training sample.
    for distances in dists.T:
        # Partitioning around index k - 1 puts the k smallest distances in
        # the first k slots (in arbitrary order) without a full sort.
        nearest = np.argpartition(distances, k - 1)[:k]
        # Order the k neighbours by distance so that vote ties are resolved
        # deterministically: Counter keeps insertion order among equal counts.
        nearest = nearest[np.argsort(distances[nearest], kind='stable')]
        votes = Counter(labels_train[nearest])
        predictions.append(votes.most_common(1)[0][0])

    return predictions

def evaluate_knn(dists, labels_train, labels_test, k):
    """Return the classification accuracy of k-NN on the test samples.

    Args:
        dists: 2-D array of distances with one row per training sample and
            one column per test sample.
        labels_train: Labels of the training samples.
        labels_test: True labels of the test samples.
        k: Number of neighbours used for the vote.

    Returns:
        The fraction of test samples whose predicted label equals the true
        label.

    Raises:
        ValueError: If the number of true labels does not match the number
            of predictions.
    """
    # Apply the k-NN algorithm, then compare the predictions with the labels.
    # Both sides are turned into flat arrays so the comparison is always
    # element-wise (a list-vs-list `==` would yield a single bool, and a
    # column vector would broadcast into a square matrix).
    predictions = np.asarray(knn_predict(dists, labels_train, k))
    labels_test = np.asarray(labels_test).ravel()

    if predictions.shape != labels_test.shape:
        raise ValueError(
            'expected ' + str(predictions.shape[0]) + ' test labels, got '
            + str(labels_test.shape[0])
        )

    return np.mean(predictions == labels_test)

def main():
    """Run k-NN on CIFAR for k = 1..num_k, print each accuracy and plot the curve.

    The plot is written to ``results/knn.png`` (the directory is created if
    it is missing) and then displayed.
    """
    # Local import: only needed here to make sure the output folder exists.
    import os

    print('#START#')

    # Set hyperparameters
    num_k = 20

    # Load the CIFAR dataset and split data and labels into train and test parts.
    # NOTE(review): 0.9 is presumably the training fraction - confirm in split_dataset.
    folder_path = 'data/cifar-10-batches-py'
    data, labels = read_cifar(folder_path)

    data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)

    # The distance matrix does not depend on k, so compute it only once.
    dists = distance_matrix(data_train, data_test)

    # Evaluate the k-NN algorithm for every k; iterate over the real k values
    # so that the printed k matches the one actually evaluated.
    k_values = range(1, num_k + 1)
    accuracies = []
    for k in k_values:
        accuracy = evaluate_knn(dists, labels_train, labels_test, k)
        print('For k = ' + str(k) + ' accuracy : ' + str(round(accuracy, 4)))
        accuracies.append(accuracy)

    # Plot the accuracy for each k
    plt.figure(figsize=(10, 6))
    plt.plot(k_values, accuracies)
    plt.xlabel('K')
    plt.ylabel('Accuracy')
    plt.title('Accuracy evolution')
    plt.grid()
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/knn.png')
    plt.show()

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()