# Reconstructed from the git patch "[PATCH] Update knn.py"
# (commit be1fe1ceaba1cff8f9c4d101f86b45cabd8eec24, dated 2022-10-22), whose
# line breaks had been collapsed onto a single line.
"""k-nearest-neighbour classification helpers built on NumPy and torch."""

import numpy as np
import torch


def _as_2d_float(data):
    """Return *data* as a floating-point array of shape (n_samples, n_features)."""
    arr = np.asarray(data)
    # Integer inputs (e.g. uint8 pixels) would overflow when squared, so
    # promote anything that is not already floating point.
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    if arr.ndim != 2:
        # Flatten every sample to a single feature vector.
        arr = arr.reshape(len(arr), -1)
    return arr


def distance_matrix(Y, X):
    """Return the pairwise Euclidean distances between the rows of Y and X.

    Args:
        Y: array-like with one sample per row (first axis). ``evaluate_knn``
            passes the training data here.
        X: array-like with one sample per row and the same number of
            features as ``Y``. ``evaluate_knn`` passes the test data here.

    Returns:
        A NumPy array ``dists`` of shape ``(len(Y), len(X))`` where
        ``dists[j, i]`` is the Euclidean distance between ``Y[j]`` and
        ``X[i]``. This is the layout ``knn_predict`` indexes (one column per
        test sample).

    Raises:
        ValueError: raised by NumPy if the feature counts of ``Y`` and ``X``
            differ.
    """
    Y = _as_2d_float(Y)
    X = _as_2d_float(X)

    # ||y - x||^2 = ||y||^2 + ||x||^2 - 2 * <y, x>, evaluated for every pair
    # at once through broadcasting and one matrix product.
    y_sq = np.sum(Y * Y, axis=1)[:, np.newaxis]
    x_sq = np.sum(X * X, axis=1)[np.newaxis, :]
    sq_dists = y_sq + x_sq - 2.0 * (Y @ X.T)

    # Rounding can push values that should be exactly zero slightly below
    # zero; clamp them so the square root never yields NaN.
    np.maximum(sq_dists, 0.0, out=sq_dists)
    return np.sqrt(sq_dists)


def knn_predict(dists, labels_train, k):
    """Predict a label for every test sample by majority vote of k neighbours.

    Args:
        dists: 2-D array or tensor of shape ``(n_train, n_test)``; column
            ``i`` holds the distances from every training sample to test
            sample ``i`` (the layout returned by ``distance_matrix``).
        labels_train: ``n_train`` non-negative integer class labels.
        k: number of nearest neighbours that vote, ``1 <= k <= n_train``.

    Returns:
        A 1-D ``torch.int64`` tensor with ``n_test`` predicted labels. When
        several labels are tied for most frequent, the smallest one wins
        (``torch.argmax`` returns the first maximum of the bin counts).

    Raises:
        ValueError: if ``dists`` is not 2-D, if the number of labels does not
            match the number of rows of ``dists``, or if ``k`` is out of
            range.
    """
    # distance_matrix returns a NumPy array while torch.topk needs a tensor;
    # as_tensor accepts both without copying when it can.
    dists = torch.as_tensor(dists)
    labels_train = torch.as_tensor(labels_train, dtype=torch.int64).reshape(-1)

    if dists.ndim != 2:
        raise ValueError("dists must be a 2-D (n_train, n_test) matrix")
    num_train, num_test = dists.shape
    if labels_train.shape[0] != num_train:
        raise ValueError(
            f"expected {num_train} training labels, got {labels_train.shape[0]}"
        )
    if not 1 <= k <= num_train:
        raise ValueError(f"k must be in [1, {num_train}], got {k}")

    # Row indices of the k smallest distances in every column: (k, n_test).
    nearest = torch.topk(dists, k, dim=0, largest=False).indices
    neighbour_labels = labels_train[nearest]

    labels_test_pred = torch.zeros(num_test, dtype=torch.int64)
    for i in range(num_test):
        # The most frequent label among the k neighbours of test sample i.
        votes = torch.bincount(neighbour_labels[:, i])
        labels_test_pred[i] = torch.argmax(votes)

    return labels_test_pred


def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the k-NN classification accuracy on the test set, in percent.

    Args:
        data_train: training samples, one per row.
        labels_train: non-negative integer label of every training sample.
        data_test: test samples, one per row.
        labels_test: true integer label of every test sample.
        k: number of neighbours used for the vote.

    Returns:
        The percentage (0-100, as a float) of test samples whose predicted
        label equals the true label.

    Raises:
        ValueError: if the test set is empty or ``labels_test`` does not have
            one label per test sample (plus anything raised by
            ``distance_matrix`` / ``knn_predict``).
    """
    dists = distance_matrix(data_train, data_test)
    labels_test_pred = knn_predict(dists, labels_train, k)

    # Flatten so that a column-shaped label array cannot broadcast against
    # the 1-D predictions and inflate the number of matches.
    labels_test = torch.as_tensor(labels_test, dtype=torch.int64).reshape(-1)
    num_samples = labels_test_pred.shape[0]
    if labels_test.shape[0] != num_samples:
        raise ValueError(
            f"expected {num_samples} test labels, got {labels_test.shape[0]}"
        )
    if num_samples == 0:
        raise ValueError("cannot compute accuracy on an empty test set")

    num_correct = (labels_test == labels_test_pred).sum().item()
    return 100 * num_correct / num_samples