diff --git a/knn.py b/knn.py index f3006bde1c5699e67ce407b33bb71d85af6da9df..511353b0799291f023c4ee61120081f17b68c166 100644 --- a/knn.py +++ b/knn.py @@ -15,12 +15,9 @@ from tqdm import tqdm def distance_matrix(A,B) : sum_of_squaresA= np.sum(A**2, axis = 1, keepdims = True) sum_of_squaresB = np.sum(B**2, axis = 1) - # sum_of_squaresA = np.tile(sum_of_squaresAVect, (np.shape(B)[0], 1)) - # sum_of_squaresB = np.tile(sum_of_squaresBVect, (np.shape(A)[0], 1)) # Calculate the dot product between the two matrices dot_product = np.dot(A, B.T) - # dot_product = np.einsum('ij,jk', A, B.T) # Calculate the Euclidean distance matrix using the hint provided dists = np.sqrt(sum_of_squaresA + sum_of_squaresB - 2 * dot_product) return dists