diff --git a/tests/distance_matrix_test.py b/tests/distance_matrix_test.py new file mode 100644 index 0000000000000000000000000000000000000000..645111728aa512e1d74df62dd1b5b3bb2d66f768 --- /dev/null +++ b/tests/distance_matrix_test.py @@ -0,0 +1,10 @@ +from knn import * +import numpy as np + +def distance_matrix_test(): + train = np.random.rand(100, 1000) + test = np.random.rand(100, 1000) + dists = distance_matrix(train, test) + assert dists.shape == (100, 80) + +distance_matrix_test()