From eec797e060940e02881f3ffc8e64f06d612d6448 Mon Sep 17 00:00:00 2001 From: Aya SAIDI <aya.saidi@auditeur.ec-lyon.fr> Date: Sat, 29 Oct 2022 23:07:33 +0100 Subject: [PATCH] Update knn.py --- knn.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/knn.py b/knn.py index f0c66e8..3225485 100644 --- a/knn.py +++ b/knn.py @@ -4,9 +4,13 @@ import torch #Functions def distance_matrix(Y , X): #This function takes as parameters two matrices X and Y - dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X))) - #dists is the euclidian distance between two matrices - return dists + a_2=(Y**2).sum(axis=1) + a_2=a_2.reshape(-1,1) + b_2=(X**2).sum(axis=1) + b_2=b_2.reshape(1,-1) + dist = np.sqrt(a_2 + b_2 -2*Y.dot(X.T)) + #dist is the euclidian distance between two matrices + return dist def knn_predict(dists, labels_train, k): #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors -- GitLab