diff --git a/knn.py b/knn.py
index f0c66e874a8039128389c3cc232fc62b24b477cf..32254851b7b3a770e3f3555a94229a52d3647092 100644
--- a/knn.py
+++ b/knn.py
@@ -4,9 +4,13 @@ import torch
 #Functions
 def distance_matrix(Y , X):
     #This function takes as parameters two matrices X and Y
-    dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X)))
-    #dists is the euclidian distance between two matrices
-    return dists
+    a_2=(Y**2).sum(axis=1)
+    a_2=a_2.reshape(-1,1)
+    b_2=(X**2).sum(axis=1)
+    b_2=b_2.reshape(1,-1)
+    dist = np.sqrt(a_2 + b_2 -2*Y.dot(X.T))
+    #dist is the euclidian distance between two matrices
+    return dist
 
 def knn_predict(dists, labels_train, k):
     #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors