From fb477c854edde35376ddd6badedd9fae7c12e915 Mon Sep 17 00:00:00 2001 From: Danjou <pierre.danjou@etu.ec-lyon.fr> Date: Mon, 11 Nov 2024 19:15:40 +0100 Subject: [PATCH] Update knn.py --- knn.py | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/knn.py b/knn.py index 672503a..ef68437 100644 --- a/knn.py +++ b/knn.py @@ -6,17 +6,7 @@ import matplotlib.pyplot as plt import numpy as np def distance_matrix(A, B): - """ - Compute the L2 Euclidean distance matrix between two matrices A and B. - - Parameters: - A (numpy.ndarray): Matrix of shape (m, n) - B (numpy.ndarray): Matrix of shape (p, n) - - Returns: - numpy.ndarray: Distance matrix of shape (m, p) where the element (i, j) is the - Euclidean distance between A[i] and B[j]. - """ + # Squared norms of each row in A and B A_squared = np.sum(A**2, axis=1).reshape(-1, 1) # Shape (m, 1) B_squared = np.sum(B**2, axis=1).reshape(1, -1) # Shape (1, p) @@ -55,19 +45,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_tests, k): accuracy = (labels_tests == result_test).sum() / N return(accuracy) -# def evaluate_knn(data_train, labels_train, data_test, labels_test, k): -# dists = distance_matrix(data_test, data_train) -# # Determine the number of images in data_test -# tot = len(data_test) -# accurate = 0 -# predicted_labels = knn_predict(dists, labels_train, k) -# # Count the number of images in data_test whose label has been estimated correctly -# for i in range(tot): -# if predicted_labels[i] == labels_test[i]: -# accurate += 1 -# # Calculate the classification rate -# accuracy = accurate/tot -# return accuracy + if __name__ == "__main__": @@ -88,15 +66,3 @@ if __name__ == "__main__": print(accurancy) -# data, labels = read_cifar('data\cifar-10-batches-py') - - -# data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9) - -# k=3 -# accurancies = [] - -# accurancy = evaluate_knn(data_train, data_test, labels_train, labels_test, k) -# accurancies.append(accurancy) - -# print(accurancies) \ No newline at end of file -- GitLab