diff --git a/knn.py b/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c77681bd76d7af57541685e7366e375c0a1c04e --- /dev/null +++ b/knn.py @@ -0,0 +1,84 @@ +import numpy as np +import pickle +import os +from read_cifar import * +import matplotlib.pyplot as plt + + +def distance_matrix (M1, M2) : + sum_squares_1 = np.sum(M1**2, axis = 1, keepdims = True) + sum_squares_2 = np.sum(M2**2, axis = 1, keepdims = True) + + dot_product = np.dot(M1, M2.T) + dists = np.sqrt(sum_squares_1 - 2*dot_product + sum_squares_2.T) + + return dists + + +def k_smallest_indexes (liste, k) : + if k <= 0 or k > len(liste) : + return [] + + indexes = list(range(len(liste))) + indexes.sort(key=lambda i: liste[i]) + + k_smallest_indexes = indexes[:k] + + return k_smallest_indexes + + +def knn_predict (data_train, labels_train, data_test, k) : + + dists = distance_matrix(data_train, data_test) + + predicted_labels = [] + + for i in range (len(data_test)) : + distance = dists[i] + labels = [] + k_nearest_neighbors = k_smallest_indexes(distance, k) + for j in k_nearest_neighbors : + labels.append(labels_train[j]) + predicted_label = max(labels, key=labels.count) + predicted_labels.append(predicted_label) + + return predicted_labels + +def evaluate_knn (data_train, labels_train, data_test, labels_test, k) : + predicted_labels = knn_predict(data_train, labels_train, data_test, k) + accuracy = 0 + for i in range (len(predicted_labels)) : + if predicted_labels[i] == labels_test[i] : + accuracy +=1 + accuracy_rate = accuracy/len(predicted_labels) + return accuracy_rate + + + +if __name__ == "__main__": + + K = 50 + split = 0.9 + + batch_dir = 'data/cifar-10-batches-py/' + data, labels = read_cifar(batch_dir) + data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split) + accuracy = evaluate_knn (data_train, labels_train, data_test, labels_test, K) + print(accuracy) + + #k = list(range(20)) + #k = [x+1 for x in k] + #accuracy_vector = [] + + #for i in k : + #accuracy_vector.append(evaluate_knn (data_train, labels_train, data_test, labels_test, i)) + + #plt.plot(k, accuracy_vector) + #plt.show() + + + + + + + \ No newline at end of file