Commit 72bb7ee1 authored by Audard Lucile

Update knn.py

parent 5d1fa2d1
 import numpy as np
+from read_cifar import *
+import matplotlib.pyplot as plt

 def distance_matrix(mat1, mat2):
-    square1 = np.sum(np.square(mat1), axis = 1)
-    square2 = np.sum(np.square(mat2), axis = 1)
+    # A^2 and B^2
+    square1 = np.sum(np.square(mat1), axis = 1, keepdims=True)
+    square2 = np.sum(np.square(mat2), axis = 1, keepdims=True)
+    # A*B
     prod = np.dot(mat1, mat2.T)
-    dists = np.sqrt(square1 + square2 - 2 * prod)
+    # A^2 + B^2 -2*A*B
+    dists = np.sqrt(square1 + square2.T - 2 * prod)
     return dists
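
A quick sanity check of the updated distance_matrix (a sketch with made-up toy arrays, not part of the commit): the vectorized A^2 + B^2 - 2*A*B expansion should agree with a direct per-pair norm computation.

# Illustration only: compare the vectorized distances with per-pair norms.
a = np.array([[1.0, 2.0], [4.0, 6.0]])                  # hypothetical "test" points
b = np.array([[2.0, 4.0], [7.0, 2.0], [4.0, 6.0]])      # hypothetical "train" points
direct = np.array([[np.linalg.norm(x - y) for y in b] for x in a])
assert np.allclose(distance_matrix(a, b), direct)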

 def knn_predict(dists, labels_train, k):
-    # results matrix initialisation
+    # results matrix initialization
     predicted_labels = np.zeros(len(dists))
     # loop on all the test images
     for i in range(len(dists)):
@@ -19,17 +24,21 @@ def knn_predict(dists, labels_train, k):
         # get the matching labels_train
         closest_labels = labels_train[k_sorted_dists]
         # get the most common labels_train
-        predicted_labels[i] = np.argmax(closest_labels)
+        uniques, counts = np.unique(closest_labels, return_counts = True)
+        predicted_labels[i] = uniques[np.argmax(counts)]
     return np.array(predicted_labels)
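
The changed vote step can be illustrated in isolation (hypothetical neighbor labels, not from the repository): np.argmax(closest_labels) returns the position of the largest label value, while the unique/counts pair returns the most frequent label.

# Illustration only: majority vote over the k closest labels.
closest_labels = np.array([3, 5, 3])                # hypothetical labels of the k nearest neighbors
uniques, counts = np.unique(closest_labels, return_counts=True)
print(uniques[np.argmax(counts)])                   # 3: the majority label (new behavior)
print(np.argmax(closest_labels))                    # 1: only the index of the largest value (old behavior)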

 def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
     dists = distance_matrix(data_test, data_train)
+    # Determine the number of images in data_test
     tot = len(data_test)
     accurate = 0
     predicted_labels = knn_predict(dists, labels_train, k)
+    # Count the number of images in data_test whose label has been estimated correctly
     for i in range(tot):
         if predicted_labels[i] == labels_test[i]:
             accurate += 1
    # Calculate the classification rate
     accuracy = accurate/tot
     return accuracy
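
The counting loop in evaluate_knn could equivalently be written as a single vectorized comparison; this is a sketch under the assumption that labels_test is a NumPy array, and the function name is hypothetical, not part of the commit.

# Sketch only: same classification rate without the explicit counting loop.
def evaluate_knn_vectorized(data_train, labels_train, data_test, labels_test, k):
    dists = distance_matrix(data_test, data_train)
    predicted_labels = knn_predict(dists, labels_train, k)
    return np.mean(predicted_labels == labels_test)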
@@ -42,14 +51,27 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
 if __name__ == "__main__":
-    bench_knn()
-    # data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
-    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
-    # print(evaluate_knn(X_train, y_train, X_test, y_test, 5))
-    # print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
+    data, labels = read_cifar("./data/cifar-10-batches-py")
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.9)
+
+    k_list = [k for k in range(1, 21)]
+    accuracy = [evaluate_knn(data_train, labels_train, data_test, labels_test, k) for k in range(1, 21)]
+    plt.plot([k for k in range(1, 21)], accuracy)
+    plt.title("Variation of k-nearest neighbors method accuracy for k from 1 to 20")
+    plt.xlabel("k value")
+    plt.ylabel("Accuracy")
+    plt.grid(True, which='both')
+    plt.savefig("results/knn.png")

-    # y_test = []
     # x_test = np.array([[1,2],[4,6]])
+    # x_labels_test = np.array([0,1])
     # x_train = np.array([[2,4],[7,2],[4,6]])
-    # y_train = [1,2,1]
+    # x_labels_train = np.array([0,1,1])
     # dist = distance_matrix(x_test, x_train)
+    # accuracy = evaluate_knn(x_train, x_labels_train, x_test, x_labels_test, 1)
+    # print(accuracy)