From 6c0a1bf6cd1f344f0fcc91b89b5b4b01a01654b7 Mon Sep 17 00:00:00 2001 From: lucile <lucile.audard@ecl20.ec-lyon.fr> Date: Fri, 10 Nov 2023 16:05:04 +0100 Subject: [PATCH] Update knn.py --- knn.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/knn.py b/knn.py index 2edab7c..a7decbb 100644 --- a/knn.py +++ b/knn.py @@ -43,35 +43,25 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k): return accuracy - - - - - - if __name__ == "__main__": - + + # Extraction of the data from Cifar database data, labels = read_cifar("./data/cifar-10-batches-py") - data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.9) + # Formatting the data into training and testing sets + data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.1) + + # Data to plot k_list = [k for k in range(1, 21)] accuracy = [evaluate_knn(data_train, labels_train, data_test, labels_test, k) for k in range (1, 21)] + # Plot the graph + plt.close() plt.plot(k_list, accuracy) plt.title("Variation of k-nearest neighbors method accuracy for k from 1 to 20") plt.xlabel("k value") plt.ylabel("Accuracy") plt.grid(True, which='both') - plt.savefig("results/knn.png") - + plt.show() + #plt.savefig("results/knn.png") - # x_test = np.array([[1,2],[4,6]]) - # x_labels_test = np.array([0,1]) - # x_train = np.array([[2,4],[7,2],[4,6]]) - # x_labels_train = np.array([0,1,1]) - - # dist = distance_matrix(x_test, x_train) - # accuracy = evaluate_knn(x_train, x_labels_train, x_test, x_labels_test, 1) - # print(accuracy) - - -- GitLab