Wrong knn prediction

When I run

```python
from knn import knn_predict
import numpy as np

dists = np.array(
    [
        [0.1, 0.3, 0.9, 1.0],
        [0.5, 0.2, 0.9, 0.2],
        [0.8, 0.0, 0.4, 0.4],
        [0.2, 0.5, 0.8, 0.6],
        [0.3, 0.5, 0.7, 0.8],
    ]
)  # shape [n_train, n_test]
labels_train = np.array([0, 0, 2, 2, 2])

pred_labels = knn_predict(dists, labels_train, k=3)
assert np.all(pred_labels == np.array([2, 0, 2, 2]))
```

I get

```
/Users/quentingallouedec/image-classification/test.py:16: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
  assert np.all(pred_labels == np.array([2, 0, 2, 2]))
Traceback (most recent call last):
  File "/Users/quentingallouedec/image-classification/test.py", line 16, in <module>
    assert np.all(pred_labels == np.array([2, 0, 2, 2]))
AssertionError
```
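For reference, the `DeprecationWarning` above is what older NumPy releases emit when `==` is applied to two arrays whose shapes cannot be broadcast together (recent releases raise an error instead). The comparison then collapses to a single `False`, so the assertion fails no matter what the labels are. A quick first check is `pred_labels.shape`, which should be `(4,)`: one label per test sample, i.e. per column of `dists`. One possible cause, for instance, is sorting along `axis=1`, which would give one label per training row (shape `(5,)`). The sketch below is a hypothetical reference `knn_predict`, not the code in `knn.py` (which is not shown here). It assumes non-negative integer labels and a majority vote that breaks ties toward the smallest label.

```python
import numpy as np


def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical reference k-NN predictor (a sketch, not the original knn.py).

    dists: distance matrix of shape [n_train, n_test].
    labels_train: non-negative integer labels of shape [n_train].
    Returns one predicted label per test sample, shape [n_test].
    """
    # Sort each column (one column per test sample) and keep the k closest training rows.
    nearest = np.argsort(dists, axis=0, kind="stable")[:k]  # shape [k, n_test]
    nearest_labels = labels_train[nearest]  # shape [k, n_test]
    n_classes = labels_train.max() + 1
    # Majority vote per column; argmax returns the first maximum, so ties go to the smallest label.
    return np.array(
        [
            np.bincount(nearest_labels[:, j], minlength=n_classes).argmax()
            for j in range(dists.shape[1])
        ]
    )


dists = np.array(
    [
        [0.1, 0.3, 0.9, 1.0],
        [0.5, 0.2, 0.9, 0.2],
        [0.8, 0.0, 0.4, 0.4],
        [0.2, 0.5, 0.8, 0.6],
        [0.3, 0.5, 0.7, 0.8],
    ]
)  # shape [n_train, n_test]
labels_train = np.array([0, 0, 2, 2, 2])

pred_labels = knn_predict(dists, labels_train, k=3)
print(pred_labels.shape)  # (4,)
print(pred_labels)  # [2 0 2 2]
```

Working the first column by hand: the three smallest distances are 0.1, 0.2 and 0.3 (training rows 0, 3 and 4), with labels 0, 2 and 2, so the vote gives 2, matching the first expected label.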