import numpy as np
import pickle
import os
from collections import Counter
from read_cifar import *
import matplotlib.pyplot as plt


def _as_float(M):
    """Return M as an ndarray, upcasting non-floating dtypes to float64.

    Floating inputs are left untouched so large float32 datasets are not
    silently doubled in memory.
    """
    M = np.asarray(M)
    if not np.issubdtype(M.dtype, np.floating):
        # Integer data (e.g. uint8 pixels) would overflow when squared.
        M = M.astype(np.float64)
    return M


def distance_matrix(M1, M2):
    """Return the pairwise Euclidean distances between the rows of M1 and M2.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so the whole
    matrix is obtained with a single matrix product.

    Args:
        M1: 2-D array of shape (n1, d).
        M2: 2-D array of shape (n2, d).

    Returns:
        Array of shape (n1, n2) whose entry [i, j] is the distance between
        M1[i] and M2[j].
    """
    M1 = _as_float(M1)
    M2 = _as_float(M2)
    sum_squares_1 = np.sum(M1 ** 2, axis=1, keepdims=True)
    sum_squares_2 = np.sum(M2 ** 2, axis=1, keepdims=True)
    dot_product = np.dot(M1, M2.T)
    squared = sum_squares_1 - 2 * dot_product + sum_squares_2.T
    # Rounding can push squared distances of (near-)identical rows slightly
    # below zero; clamp so sqrt never produces NaN.
    np.maximum(squared, 0, out=squared)
    return np.sqrt(squared)


def k_smallest_indexes(liste, k):
    """Return the indexes of the k smallest values of ``liste``, smallest first.

    Equal values are ordered by index (stable sort).

    Args:
        liste: 1-D sequence or array of comparable values.
        k: number of indexes to return.

    Returns:
        A list of ``k`` Python ints, or ``[]`` when ``k <= 0`` or
        ``k > len(liste)``.
    """
    if k <= 0 or k > len(liste):
        return []
    order = np.argsort(np.asarray(liste), kind="stable")
    return order[:k].tolist()


def knn_predict(data_train, labels_train, data_test, k):
    """Predict a label for every row of ``data_test`` by k-nearest-neighbour vote.

    Args:
        data_train: 2-D array of training samples, one per row.
        labels_train: labels aligned with the rows of ``data_train``.
        data_test: 2-D array of samples to classify, one per row.
        k: number of neighbours that vote; must satisfy
            1 <= k <= len(data_train).

    Returns:
        A list with one predicted label per test row. On a tie the label
        that appears first among the nearest neighbours wins.

    Raises:
        ValueError: if ``k`` is outside ``[1, len(data_train)]``.
    """
    # Rows index test samples, columns index training samples.
    dists = distance_matrix(data_test, data_train)
    n_train = dists.shape[1]
    if not 1 <= k <= n_train:
        raise ValueError(f"k must be between 1 and {n_train}, got {k}")

    predicted_labels = []
    for row in dists:
        nearest = k_smallest_indexes(row, k)
        # Counter keeps first-seen order, and most_common(1) returns the first
        # maximal entry, so ties resolve to the closest neighbour's label.
        votes = Counter(labels_train[j] for j in nearest)
        predicted_labels.append(votes.most_common(1)[0][0])
    return predicted_labels


def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the fraction of test samples the k-NN classifier labels correctly.

    Args:
        data_train, labels_train: training samples and their labels.
        data_test, labels_test: evaluation samples and their true labels.
        k: number of neighbours used by :func:`knn_predict`.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the test set is empty or ``labels_test`` does not match
            ``data_test`` in length (plus any error from :func:`knn_predict`).
    """
    predicted_labels = knn_predict(data_train, labels_train, data_test, k)
    if len(predicted_labels) == 0:
        raise ValueError("data_test must contain at least one sample")
    if len(labels_test) != len(predicted_labels):
        raise ValueError("labels_test and data_test must have the same length")
    correct = sum(
        1 for predicted, truth in zip(predicted_labels, labels_test)
        if predicted == truth
    )
    return correct / len(predicted_labels)


def plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_max=20):
    """Plot the accuracy obtained for every k in 1..k_max and show the figure.

    Note: the distance matrix is recomputed for each k, so this is slow on
    the full CIFAR-10 split.

    Returns:
        The list of accuracies, indexed by k - 1.
    """
    k_values = list(range(1, k_max + 1))
    accuracies = [
        evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        for k in k_values
    ]
    plt.plot(k_values, accuracies)
    plt.xlabel("k")
    plt.ylabel("accuracy")
    plt.show()
    return accuracies


if __name__ == "__main__":
    K = 50
    split = 0.9
    batch_dir = 'data/cifar-10-batches-py/'
    data, labels = read_cifar(batch_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, K)
    print(accuracy)
    # To study the influence of k, call:
    # plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_max=20)