Skip to content
Snippets Groups Projects
Commit be1fe1ce authored by Saidi Aya's avatar Saidi Aya
Browse files

Update knn.py

parent 08be5a9b
No related branches found
No related tags found
No related merge requests found
#Libraries
import numpy as np
import torch
#Functions
def distance_matrix(Y , X):
#This function takes as parameters two matrices X and Y
dists = np.sqrt(np.sum(-2 * np.multiply(X, Y)+ np.multiply(Y, Y) + np.multiply(X, X)))
......@@ -5,4 +9,23 @@ def distance_matrix(Y , X):
return dists
def knn_predict(dists, labels_train, k):
#This function takes as parameters: dists (from above), labels_train, and k the number of neighbors
labels_test_pred=torch.zeros(len(data_test), dtype=torch.int64)
for i in range(dists.shape[1]):
# Find index of k lowest values
x = torch.topk(dists[:,i], k, largest=False).indices
# Index the labels according to x
k_lowest_labels = labels_train[x]
# y_test_pred[i] = the most frequent occuring index
labels_test_pred[i] = torch.argmax(torch.bincount(k_lowest_labels))
return labels_test_pred
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
labels_test_pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k)
num_samples= data_test.shape[0]
num_correct= (labels_test == labels_test_pred).sum().item()
accuracy= 100 * num_correct / num_samples
return accuracy
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment