From b4689a40a92bdeb310934c427b07c0243f408351 Mon Sep 17 00:00:00 2001 From: Aya SAIDI <aya.saidi@auditeur.ec-lyon.fr> Date: Sun, 6 Nov 2022 16:58:12 +0100 Subject: [PATCH] Update knn.py --- knn.py | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/knn.py b/knn.py index 3225485..735e503 100644 --- a/knn.py +++ b/knn.py @@ -1,6 +1,5 @@ #Libraries import numpy as np -import torch #Functions def distance_matrix(Y , X): #This function takes as parameters two matrices X and Y @@ -14,22 +13,38 @@ def distance_matrix(Y , X): def knn_predict(dists, labels_train, k): #This function takes as parameters: dists (from above), labels_train, and k the number of neighbors - labels_test_pred=torch.zeros(len(data_test), dtype=torch.int64) - for i in range(dists.shape[1]): - # Find index of k lowest values - x = torch.topk(dists[:,i], k, largest=False).indices - - # Index the labels according to x - k_lowest_labels = labels_train[x] - - # y_test_pred[i] = the most frequent occuring index - labels_test_pred[i] = torch.argmax(torch.bincount(k_lowest_labels)) - - return labels_test_pred + labels_pred=np.zeros(labels_train.shape[0]) + for i in range(0,dists.shape[0]): + # Find index of k smallest distances + index_smallest_distance = np.argsort(dists[i,:])[0:k+1] + # Index the labels according to these distances + labels_distances = [labels_train[i] for i in index_smallest_distance] + #Predict the class / label + labels_pred[i]=max(labels_distances,key=labels_distances.count) + return labels_pred def evaluate_knn(data_train, labels_train, data_test, labels_test, k): - labels_test_pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k) + #This function evaluates the knn classifier rate + labels_test__pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k) num_samples= data_test.shape[0] num_correct= (labels_test == labels_test_pred).sum().item() - accuracy= 100 * num_correct / num_samples + accuracy= 100 * (num_correct / num_samples) #The accuracy is the percentage of the correctly predicted classes return accuracy + + +def accuracy_graph(k,dirname,num_batch): + #This function is used to plot the variation of the accuracy as a function of k + # k -- the max number of neighbors + x=[] #axis x : k + y=[] #axis y : accuracy + dir_batch=str(dirname)+"\\data\\cifar-10-batches-py\\data_batch_"+str(num_batch) + dir_test = str(dirname)+"\\data\\cifar-10-batches-py\\test_batch" + (data_test, labels_test)=read_cifar_batch(dir_test) + (data_train, labels_train)=read_cifar_batch(dir_batch) + for i in range (1,k+1): + x.append(i) #axis (k from 1 to 20) + accuracy=evaluate_knn(data_train , labels_train , data_test , labels_test , i) + y.append(accuracy) + plt.plot(x,y) + plt.Show + plt.savefig(str(dirname)+"results") -- GitLab