From 6c0a1bf6cd1f344f0fcc91b89b5b4b01a01654b7 Mon Sep 17 00:00:00 2001
From: lucile <lucile.audard@ecl20.ec-lyon.fr>
Date: Fri, 10 Nov 2023 16:05:04 +0100
Subject: [PATCH] Update knn.py

---
 knn.py | 30 ++++++++++--------------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/knn.py b/knn.py
index 2edab7c..a7decbb 100644
--- a/knn.py
+++ b/knn.py
@@ -43,35 +43,25 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
     return accuracy
 
 
-
-
-
-
-
-
 if __name__ == "__main__":
-
+    
+    # Extraction of the data from Cifar database
     data, labels = read_cifar("./data/cifar-10-batches-py")
-    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.9)
     
+    # Formatting the data into training and testing sets
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.1)
+    
+    # Data to plot
     k_list = [k for k in range(1, 21)]
     accuracy = [evaluate_knn(data_train, labels_train, data_test, labels_test, k) for k in range (1, 21)]
     
+    # Plot the graph
+    plt.close()
     plt.plot(k_list, accuracy)
     plt.title("Variation of k-nearest neighbors method accuracy for k from 1 to 20")
     plt.xlabel("k value")
     plt.ylabel("Accuracy")
     plt.grid(True, which='both')
-    plt.savefig("results/knn.png")
-
+    plt.show()
+    #plt.savefig("results/knn.png")
 
-    # x_test = np.array([[1,2],[4,6]])
-    # x_labels_test = np.array([0,1])
-    # x_train = np.array([[2,4],[7,2],[4,6]])
-    # x_labels_train = np.array([0,1,1])
-
-    # dist = distance_matrix(x_test, x_train)
-    # accuracy = evaluate_knn(x_train, x_labels_train, x_test, x_labels_test, 1)
-    # print(accuracy)
-
-    
-- 
GitLab