diff --git a/knn.py b/knn.py index f0c66e874a8039128389c3cc232fc62b24b477cf..32254851b7b3a770e3f3555a94229a52d3647092 100644 --- a/knn.py +++ b/knn.py @@ -4,9 +4,13 @@ import torch #Functions def distance_matrix(Y , X): #This function takes as parameters two matrices X and Y - dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X))) - #dists is the euclidian distance between two matrices - return dists + a_2=(Y**2).sum(axis=1) + a_2=a_2.reshape(-1,1) + b_2=(X**2).sum(axis=1) + b_2=b_2.reshape(1,-1) + dist = np.sqrt(a_2 + b_2 -2*Y.dot(X.T)) + #dist is the euclidian distance between two matrices + return dist def knn_predict(dists, labels_train, k): #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors