From 8cb839500de9cc2e513e6fbc58fc624248d43154 Mon Sep 17 00:00:00 2001 From: BaptisteBrd <75663738+BaptisteBrd@users.noreply.github.com> Date: Fri, 10 Nov 2023 23:18:05 +0100 Subject: [PATCH] modif knn --- knn.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/knn.py b/knn.py index 9e44028..365c081 100644 --- a/knn.py +++ b/knn.py @@ -3,26 +3,21 @@ import read_cifar import matplotlib.pyplot as plt def distance_matrix(A,B): - + # Calculating the squared sum of elements in each row for matrices A and B sum_of_squares_A = np.sum(A**2, axis=1,keepdims=True) sum_of_squares_B = np.sum(B**2, axis=1,keepdims=True).T dot_product = np.dot(A, B.T) - + # Computing the Euclidean distance matrix dists=np.sqrt(sum_of_squares_A+sum_of_squares_B-2*dot_product) return dists - - -#def knn_predict(dists, labels_train, k): - # - # def knn_predict(dists, labels_train, k): predicted_labels = [] - # For every image in the test set + # Iterating through each test data point's distances to train data for i in range(len(dists)): - # Initialize an array to store the neighbors + # Counting the frequency of each class among the k nearest neighbors classes = [0] * 10 # indexes of the closest neighbors indexes_closest_nb = np.argsort(dists[i])[:k] @@ -37,8 +32,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k): rate = 0 dist_train_test = distance_matrix(data_test, data_train) prediction = knn_predict(dist_train_test, labels_train, k) - print(len(prediction)) - print(len(labels_test)) + # Comparing predictions to actual test labels to calculate accuracy for j in range(len(prediction)): if prediction[j]==labels_test[j]: rate +=1 @@ -51,12 +45,16 @@ def knn_final(): data,labels = read_cifar.read_cifar("data/cifar-10-batches-py") data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9) + + # Testing KNN for different values of k and storing the accuracy for k in range_k : print(k) rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k) rates.append(rate_k) + # Plotting the accuracy as a function of k + plt.figure(figsize=(10, 7)) plt.xlabel('k') plt.ylabel('Accuracy rate') @@ -72,7 +70,4 @@ def knn_final(): if __name__ == "__main__" : - knn_final() - #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]]) - #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) - #print(distance_matrix(a1,b1)) \ No newline at end of file + knn_final() \ No newline at end of file -- GitLab