import numpy as np


def distance_matrix(mat1, mat2):
    """Return the pairwise Euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so the whole
    matrix is obtained from a single matrix product instead of a Python loop.

    Args:
        mat1: array-like of shape (n, d).
        mat2: array-like of shape (m, d).

    Returns:
        float64 ndarray of shape (n, m); entry [i, j] is the distance between
        row i of ``mat1`` and row j of ``mat2``.
    """
    # Work in float64: with integer inputs (e.g. uint8 pixels) np.square and
    # np.dot wrap around instead of producing the true squared norms.
    mat1 = np.asarray(mat1, dtype=np.float64)
    mat2 = np.asarray(mat2, dtype=np.float64)
    # Column vector + row vector broadcasts to the (n, m) outer sum of norms.
    square1 = np.sum(np.square(mat1), axis=1)[:, np.newaxis]  # (n, 1)
    square2 = np.sum(np.square(mat2), axis=1)[np.newaxis, :]  # (1, m)
    prod = mat1 @ mat2.T  # (n, m)
    # Rounding can push the squared distance of (near-)identical rows slightly
    # below zero; clamp so sqrt never yields NaN.
    squared = np.maximum(square1 + square2 - 2.0 * prod, 0.0)
    return np.sqrt(squared)

def knn_predict(dists, labels_train, k):
    """Predict one label per test sample by majority vote of its k nearest neighbours.

    Args:
        dists: array-like of shape (n_test, n_train); ``dists[i, j]`` is the
            distance between test sample i and training sample j.
        labels_train: array-like of length n_train with the training labels.
        k: number of neighbours that vote (must be >= 1). If it exceeds
            n_train, all training samples vote.

    Returns:
        ndarray of length n_test holding the predicted labels (same dtype as
        ``labels_train``). When several labels tie for the most votes, the
        smallest label value wins, because ``np.unique`` returns sorted values.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    dists = np.asarray(dists)
    # asarray so plain Python lists can be fancy-indexed below.
    labels_train = np.asarray(labels_train)
    predicted_labels = np.empty(len(dists), dtype=labels_train.dtype)
    for i, row in enumerate(dists):
        # Indices of the k training samples closest to test sample i.
        nearest = np.argsort(row)[:k]
        # Majority vote: count each label among the neighbours and keep the
        # most frequent one (not the largest label value).
        values, counts = np.unique(labels_train[nearest], return_counts=True)
        predicted_labels[i] = values[np.argmax(counts)]
    return predicted_labels

def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the classification accuracy of k-NN on a test set.

    Each test row is classified with ``knn_predict`` using Euclidean
    distances from ``distance_matrix``. The result is compared with
    ``labels_test``.

    Args:
        data_train: array-like of shape (n_train, d).
        labels_train: array-like of length n_train.
        data_test: array-like of shape (n_test, d).
        labels_test: array-like of length n_test.
        k: number of neighbours that vote for each prediction.

    Returns:
        Fraction of test samples predicted correctly, as a float in [0, 1].

    Raises:
        ValueError: if the test set is empty or ``labels_test`` does not have
            one label per test sample.
    """
    n_test = len(data_test)
    if n_test == 0:
        raise ValueError("data_test must contain at least one sample")
    labels_test = np.asarray(labels_test)
    if len(labels_test) != n_test:
        raise ValueError(
            f"labels_test has {len(labels_test)} entries, expected {n_test}"
        )
    dists = distance_matrix(data_test, data_train)
    predicted_labels = knn_predict(dists, labels_train, k)
    # Element-wise comparison gives a boolean array; its mean is the hit rate.
    return float(np.mean(predicted_labels == labels_test))








if __name__ == "__main__":
    # Toy sanity check that runs without any dataset on disk.
    x_train = np.array([[2, 4], [7, 2], [4, 6]])
    y_train = np.array([1, 2, 1])
    x_test = np.array([[1, 2], [4, 6]])
    y_test = np.array([1, 1])
    print(distance_matrix(x_test, x_train))
    print(evaluate_knn(x_train, y_train, x_test, y_test, 1))

    # Evaluating on CIFAR-10 needs the project's read_cifar module, which this
    # file does not import:
    # data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
    # print(evaluate_knn(X_train, y_train, X_test, y_test, 5))