From 5aabffaade4db6f75005aaca53ff0f06138c1c2c Mon Sep 17 00:00:00 2001
From: Milan <milan.cart@ecl20.ec-lyon.fr>
Date: Thu, 9 Nov 2023 23:34:03 +0100
Subject: [PATCH] Modif

---
 knn.py         |  46 +++++++++++++---------------------------------
 mlp.py         |   0
 result/knn.png | Bin 0 -> 2396 bytes
 3 files changed, 13 insertions(+), 33 deletions(-)
 create mode 100644 mlp.py
 create mode 100644 result/knn.png

diff --git a/knn.py b/knn.py
index 0a687fb..3ea3ce9 100644
--- a/knn.py
+++ b/knn.py
@@ -3,58 +3,47 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 def distance_matrix(matrix1, matrix2):
-    #X_test then X_train in this order
-    sum_of_squares_matrix1 = np.sum(np.square(matrix1), axis=1, keepdims=True) #A^2
-    sum_of_squares_matrix2 = np.sum(np.square(matrix2), axis=1, keepdims=True) #B^2
+    sum_of_squares_matrix1 = np.sum(np.square(matrix1), axis=1, keepdims=True) 
+    sum_of_squares_matrix2 = np.sum(np.square(matrix2), axis=1, keepdims=True) 
 
-    dot_product = np.dot(matrix1, matrix2.T) # A * B (matrix mutliplication)
+    dot_product = np.dot(matrix1, matrix2.T)
     
-    dists = np.sqrt(sum_of_squares_matrix1 + sum_of_squares_matrix2.T - 2 * dot_product) # Compute the product
+    dists = np.sqrt(sum_of_squares_matrix1 + sum_of_squares_matrix2.T - 2 * dot_product)
     return dists
 
 def knn_predict(dists, labels_train, k):
     output = []
-    # Loop on all the images_test
+    
     for i in range(len(dists)):
-        # Innitialize table to store the neighbors
+        
         res = [0] * 10
-        # Get the closest neighbors
         labels_close = np.argsort(dists[i])[:k]
         for label in labels_close:
-            #add a label to the table of result
             res[labels_train[label]] += 1
-        # Get the class with the maximum neighbors
-        label_temp = np.argmax(res) #Careful to the logic here, if there is two or more maximum, the function the first maximum encountered
+        label_temp = np.argmax(res) 
         output.append(label_temp)
     return(np.array(output))
 
 def evaluate_knn(data_train, labels_train, data_test, labels_tests, k):
     dist = distance_matrix(data_test, data_train)
     result_test = knn_predict(dist, labels_train, k)
-
-    #accuracy 
+ 
     N = labels_tests.shape[0]
     accuracy = (labels_tests == result_test).sum() / N
     return(accuracy)
 
-def bench_knn() :
+def test_knn() :
 
     k_indices = [i for i in range(20) if i % 2 != 0]
     accuracies = []
 
-    # Load data
     data, labels = read_cifar.read_cifar('/Users/milancart/Documents/GitHub/image-classification/Data/cifar-10-batches-py')
     X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
-    #Load one batch
-    # data, labels = read_cifar.read_cifar_batch('image-classification/data/cifar-10-batches-py/data_batch_1')
-    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
-
-    # Loop on the k_indices to get all the accuracies
+    
     for k in k_indices :
         accuracy = evaluate_knn(X_train, y_train, X_test, y_test, k)
         accuracies.append(accuracy)
     
-    # Save and show the graph of accuracies
     plt.figure(figsize=(8, 6))
     plt.xlabel('K')
     plt.ylabel('Accuracy')
@@ -66,15 +55,6 @@ def bench_knn() :
 
 
 if __name__ == "__main__":
-    print('milan')
-    bench_knn()
-    data, labels = read_cifar.read_cifar('/Users/milancart/Documents/GitHub/image-classification/Data/cifar-10-batches-py')
-    X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
-    print(evaluate_knn(X_train, y_train, X_test, y_test, 5))
-    print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
-
-    y_test = []
-    x_test = np.array([[1,2],[4,6]])
-    x_train = np.array([[2,4],[7,2],[4,6]])
-    y_train = [1,2,1]
-    dist = distance_matrix(x_test,x_train)      
\ No newline at end of file
+    
+    test_knn()
+    
\ No newline at end of file
diff --git a/mlp.py b/mlp.py
new file mode 100644
index 0000000..e69de29
diff --git a/result/knn.png b/result/knn.png
new file mode 100644
index 0000000000000000000000000000000000000000..72354204692aa86cfb0bce31acc22fdbb378fc2f
GIT binary patch
literal 2396
zcmeAS@N?(olHy`uVBq!ia0y~yU}|7sV0^&A#=yW}dhyN^1_lPp64!{5;QX|b^2DN4
z2H(Vzf}H%4oXjMJvecsD%=|oKJ##%H9fgdNl7eC@ef?ax0=@jAbbZa(r}YdB44efX
zk;M!Q{D~mUxWayUCIbV<NlzEYkcv5PuNg8jDDWIGII#cuL#948uBx*0XLJ}CjLyAh
zU|_gs)WE=C$R@y`aEyV8;fRC-LxTh}2ZKN|10zEcj{<{34<idhhvBHA(I6O21)~{Z
ov@94c4o7Q*(Mpl7?E*cuO7=a=mv2g9U|?YIboFyt=akR{0QjPxK>z>%

literal 0
HcmV?d00001

-- 
GitLab