import numpy as np
import read_cifar
import matplotlib.pyplot as plt

def distance_matrix(a, b):
    """Return the pairwise Euclidean distances between the rows of ``a`` and ``b``.

    Uses the expansion ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2 so the whole
    matrix is obtained from a single matrix product instead of Python loops.

    Args:
        a: 2-D array of shape (n, d); each row is one sample.
        b: 2-D array of shape (m, d); each row is one sample.

    Returns:
        Array of shape (n, m) whose entry [i, j] is the Euclidean distance
        between ``a[i]`` and ``b[j]``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Squaring integer pixel values (e.g. uint8) would wrap around, so
    # non-float inputs are promoted to float64 first.
    if not np.issubdtype(a.dtype, np.floating):
        a = a.astype(np.float64)
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype(np.float64)

    sq_norm_a = np.sum(a**2, axis=1, keepdims=True)    # shape (n, 1)
    sq_norm_b = np.sum(b**2, axis=1, keepdims=True).T  # shape (1, m): must be a row
    sq_dist = sq_norm_a - 2 * a.dot(b.T) + sq_norm_b    # broadcasts to (n, m)

    # Floating-point cancellation can leave tiny negative values for
    # (near-)identical rows; clamp them so sqrt never yields NaN.
    np.maximum(sq_dist, 0, out=sq_dist)
    return np.sqrt(sq_dist)



def knn_predict(dists, labels_train, k):
    """Predict a label for every test sample by majority vote of its k nearest neighbours.

    Args:
        dists: 2-D array-like of shape (n_test, n_train); ``dists[i, j]`` is
            the distance between test sample ``i`` and training sample ``j``.
        labels_train: Non-negative integer labels of the training samples,
            indexed like the columns of ``dists``.
        k: Number of nearest neighbours that vote.

    Returns:
        1-D array of length n_test with the predicted label of each test
        sample. Vote ties are resolved in favour of the smallest label
        (``np.argmax`` returns the first maximum).
    """
    labels_train = np.asarray(labels_train)
    predicted_labels = []
    for row in np.asarray(dists):
        # Column indices of the k closest training samples; a stable sort
        # makes the choice deterministic when several distances are equal.
        nearest = np.argsort(row, kind="stable")[:k]
        # At least 10 vote slots (CIFAR-10); bincount grows automatically
        # if a label is >= 10 instead of raising IndexError.
        votes = np.bincount(labels_train[nearest], minlength=10)
        predicted_labels.append(np.argmax(votes))
    return np.array(predicted_labels)

def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the accuracy of the k-NN classifier on a labelled test set.

    Args:
        data_train: Training samples, one per row.
        labels_train: Integer labels of the training samples.
        data_test: Test samples, one per row.
        labels_test: True labels of the test samples.
        k: Number of nearest neighbours used by ``knn_predict``.

    Returns:
        Fraction (float in [0, 1]) of test samples whose predicted label
        equals the true label.

    Raises:
        ValueError: if the test set is empty, or if ``labels_test`` does not
            have one label per test sample.
    """
    # knn_predict expects one row per *test* sample and one column per
    # training sample, so the test set must be the first argument here.
    dists = distance_matrix(data_test, data_train)
    prediction = knn_predict(dists, labels_train, k)

    if len(prediction) == 0:
        raise ValueError("cannot evaluate accuracy on an empty test set")
    labels_test = np.asarray(labels_test)
    if len(labels_test) != len(prediction):
        raise ValueError(
            f"labels_test has {len(labels_test)} entries but there are "
            f"{len(prediction)} test samples"
        )
    return float(np.mean(prediction == labels_test))

def knn_final(k_values=range(1, 20), data_dir="data/cifar-10-batches-py", split_factor=0.9):
    """Evaluate k-NN on CIFAR-10 for several values of k and plot accuracy vs. k.

    Args:
        k_values: Iterable of k values to evaluate (default 1..19).
        data_dir: Directory passed to ``read_cifar.read_cifar``.
        split_factor: Value passed to ``read_cifar.split_dataset``
            (presumably the training fraction — confirm against read_cifar).

    Returns:
        List of accuracy rates, one per entry of ``k_values``.
    """
    k_values = list(k_values)

    data, labels = read_cifar.read_cifar(data_dir)
    data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(
        data, labels, split_factor
    )

    # The test-to-train distance matrix does not depend on k, so compute it
    # once and reuse it instead of rebuilding it for every k.
    dists = distance_matrix(data_test, data_train)
    labels_test = np.asarray(labels_test)
    rates = [
        float(np.mean(knn_predict(dists, labels_train, k) == labels_test))
        for k in k_values
    ]

    plt.figure(figsize=(10, 7))
    plt.plot(k_values, rates, label="k-NN accuracy")  # labelled so legend() has an entry
    plt.xlabel('k')
    plt.ylabel('Accuracy rate')
    plt.title("Accuracy rate = f(k)")
    plt.legend()
    plt.grid(True)
    plt.show()
    return rates





if __name__ == "__main__":
    # Run the whole experiment: load the data, sweep k and plot the accuracy.
    knn_final()