Skip to content
Snippets Groups Projects
Commit eec797e0 authored by Saidi Aya's avatar Saidi Aya
Browse files

Update knn.py

parent 0c096b6e
No related branches found
No related tags found
No related merge requests found
...@@ -4,9 +4,13 @@ import torch ...@@ -4,9 +4,13 @@ import torch
#Functions #Functions
def distance_matrix(Y , X): def distance_matrix(Y , X):
#This function takes as parameters two matrices X and Y #This function takes as parameters two matrices X and Y
dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X))) a_2=(Y**2).sum(axis=1)
#dists is the euclidian distance between two matrices a_2=a_2.reshape(-1,1)
return dists b_2=(X**2).sum(axis=1)
b_2=b_2.reshape(1,-1)
dist = np.sqrt(a_2 + b_2 -2*Y.dot(X.T))
#dist is the euclidian distance between two matrices
return dist
def knn_predict(dists, labels_train, k): def knn_predict(dists, labels_train, k):
#This function takes as parameters: dists (from above), labels_train, and k the number of neighbors #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment