# NOTE: the star import is kept first (as in the original file) so that the
# explicit imports below always win if read_cifar exports a clashing name.
from read_cifar import *

import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np  # imported explicitly instead of relying on read_cifar's star export


def _as_float(data):
    """Return *data* as a floating-point array.

    Floating inputs keep their dtype; anything else (e.g. integer pixel
    values) is promoted to float64 so that squaring cannot overflow.
    """
    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.floating):
        return data
    return data.astype(np.float64)


def distance_matrix(data_train, data_test):
    """Compute the Euclidean distance between every train/test pair.

    The expansion ||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2 is used so the
    whole matrix comes out of a single matrix product.

    Args:
        data_train: 2-D array with one training sample per row.
        data_test: 2-D array with one testing sample per row.

    Returns:
        Array ``dists`` where ``dists[i][j]`` is the distance between the
        i-th row of ``data_train`` and the j-th row of ``data_test``.
    """
    data_train = _as_float(data_train)
    data_test = _as_float(data_test)

    train_squared = np.sum(data_train ** 2, axis=1, keepdims=True)
    test_squared = np.sum(data_test ** 2, axis=1, keepdims=True)
    dot_product = np.dot(data_train, data_test.T)

    squared_dists = train_squared - 2 * dot_product + test_squared.T
    # Floating-point rounding can leave a tiny negative value where the true
    # distance is ~0 (e.g. identical samples); clamp so sqrt never gives NaN.
    np.maximum(squared_dists, 0, out=squared_dists)
    return np.sqrt(squared_dists)


def knn_predict(dists, labels_train, k):
    """Predict a label for every test sample by majority vote of its k nearest
    training samples.

    Args:
        dists: distance matrix with training samples on the rows and testing
            samples on the columns (as returned by ``distance_matrix``).
        labels_train: 1-D sequence of training labels, aligned with the rows
            of ``dists``.
        k: number of neighbours to consult, ``1 <= k <= number of rows``.

    Returns:
        A list with one predicted label per column of ``dists``.

    Raises:
        ValueError: if ``k`` is outside ``1..number of training samples``.
    """
    dists = np.asarray(dists)
    labels_train = np.asarray(labels_train)
    n_train = dists.shape[0]
    if not 1 <= k <= n_train:
        raise ValueError(
            'k must be between 1 and the number of training samples ('
            + str(n_train) + '), got ' + str(k)
        )

    # argpartition needs kth < n_train; clamping lets k == n_train work while
    # leaving the selection unchanged for every smaller k.
    kth = min(k, n_train - 1)

    predictions = []
    # Transpose so that each iteration sees all train distances of one test sample.
    for distances in dists.T:
        nearest = np.argpartition(distances, kth)[:k]
        votes = Counter(labels_train[nearest])
        predictions.append(votes.most_common(1)[0][0])
    return predictions


def evaluate_knn(dists, labels_train, labels_test, k):
    """Run k-NN and return the fraction of test samples predicted correctly.

    Args:
        dists: train-by-test distance matrix.
        labels_train: training labels, aligned with the rows of ``dists``.
        labels_test: true test labels, aligned with the columns of ``dists``.
        k: number of neighbours used for the vote.

    Returns:
        Mean of the element-wise equality between predictions and
        ``labels_test``.
    """
    predictions = np.asarray(knn_predict(dists, labels_train, k))
    # Convert both sides to arrays: comparing two plain lists with == would
    # yield a single bool instead of an element-wise result.
    return np.mean(predictions == np.asarray(labels_test))


def main():
    """Evaluate k-NN for k = 1..num_k, print each accuracy and plot the curve."""
    print('#START#')

    # Set hyperparameters
    num_k = 20

    # Load the dataset and split data and labels for the train and test phases
    # (0.9 is presumably the training fraction - confirm against read_cifar).
    folder_path = 'data/cifar-10-batches-py'
    data, labels = read_cifar(folder_path)
    data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)

    # The distance matrix does not depend on k, so compute it only once
    dists = distance_matrix(data_train, data_test)

    # Test the knn algorithm for each value of k
    k_values = range(1, num_k + 1)
    accuracies = []
    for k in k_values:
        accuracy = evaluate_knn(dists, labels_train, labels_test, k)
        # Report the k that was actually evaluated
        print('For k = ' + str(k) + ' accuracy : ' + str(round(accuracy, 4)))
        accuracies.append(accuracy)

    # Plot the accuracy for each k
    plt.figure(figsize=(10, 6))
    plt.plot(k_values, accuracies)
    plt.xlabel('K')
    plt.ylabel('Accuracy')
    plt.title('Accuracy evolution')
    plt.grid()
    # savefig does not create missing directories
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/knn.png')
    plt.show()


if __name__ == "__main__":
    main()