From eec797e060940e02881f3ffc8e64f06d612d6448 Mon Sep 17 00:00:00 2001
From: Aya SAIDI <aya.saidi@auditeur.ec-lyon.fr>
Date: Sat, 29 Oct 2022 23:07:33 +0100
Subject: [PATCH] Update knn.py

---
 knn.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/knn.py b/knn.py
index f0c66e8..3225485 100644
--- a/knn.py
+++ b/knn.py
@@ -4,9 +4,13 @@ import torch
 #Functions
 def distance_matrix(Y , X):
     #This function takes as parameters two matrices X and Y
-    dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X)))
-    #dists is the euclidian distance between two matrices
-    return dists
+    a_2=(Y**2).sum(axis=1)
+    a_2=a_2.reshape(-1,1)
+    b_2=(X**2).sum(axis=1)
+    b_2=b_2.reshape(1,-1)
+    dist = np.sqrt(a_2 + b_2 -2*Y.dot(X.T))
+    #dist is the euclidian distance between two matrices
+    return dist
 
 def knn_predict(dists, labels_train, k):
     #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors
-- 
GitLab