diff --git a/knn.py b/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5b8ef94591e2d577016b37351e6653f347c23d --- /dev/null +++ b/knn.py @@ -0,0 +1,19 @@ +import numpy as np +import pickle +import os + + +def distance_matrix (M1, M2) : + sum_squares_1 = np.sum(M1**2, axis = 1, keepdims = True) + sum_squares_2 = np.sum(M2**2, axis = 1, keepdims = True) + + dot_product = np.dot(M1, M2.T) + dists = np.sqrt(sum_squares_1 - 2*dot_product + sum_squares_2.T) + + return dists + +split_dataset (data, labels, split) + +def knn_predict (dists, labels_train, k) : + train_set + test_set \ No newline at end of file