From cfaf9737cbb14ee2cad40813abcdf529ce3895ca Mon Sep 17 00:00:00 2001
From: Malo Bourry <malo.bourry@ecl20.ec-lyon.fr>
Date: Fri, 20 Oct 2023 22:09:51 +0200
Subject: [PATCH] Partie knn finie

---
 __pycache__/read_cifar.cpython-38.pyc | Bin 0 -> 1871 bytes
 knn.py                                |  42 +++++++++++++++++++++++---
 read_cifar.py                         |   6 ++--
 3 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100644 __pycache__/read_cifar.cpython-38.pyc

diff --git a/__pycache__/read_cifar.cpython-38.pyc b/__pycache__/read_cifar.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d01a55faa65f9a5c3e47242daef27b820c55d340
GIT binary patch
literal 1871
zcmWIL<>g{vU|=Y_W0Wew!NBks#6iX^3=9ko3=9m#c?=8;DGVu$ISf&ZV45k4DU~^e
zX%1rwa|%lfOB72AYYJNnLlkQ&V-_12v!}4nVNT&l;cQ`v;z;F8;mT$z%1UKzW{6>m
z;)>!<<w)g9<xb&lW{zQr;)&vo;)~)>6-eP}W@KbYVGL%_<b4TpiJvCpE#88}l8pHL
zwD_dNlH`nJMvyoZvobI+a56A3ID;G+!N9;!!jQsP!ywL(!qm*vFH*}`!kEQW19oID
zQ!P_CgDFEGLl8p*LkV*hOAS*Bt0aiUn!?u0RLcw%V+V_|fyG#Am}^+ln1UHJIsAS}
zrX-dm>L+KWB^K!#8t8&ulUl4>P^o{5sVM0dYf@!NYVk_OTO6r*$@wXndFh`);Rn*B
z$#{!3rKGYT^)nj-14B+?Qff}ICi^Xx{DRcHTWkfH$=Nxnw^(xW6H{(6<rUmwO)M%(
zth~jZmXn`YVr+DaH8Zco%tVv5h>3xL;TBU`{w<c2%;b_=EFiOQu|dpR$xtN5z`*dU
zBiSk@v^ce>IL0?ICqKp|Ke;qFHLs*N#yK^wq$n{bHO951G$pk-#yP(fq{!c;IL0%t
zq$oe7G`S=*KTjdCRKX=RwLrlqHL)l!GcP?RGdD3kH9k2fvA8%hEi*Y0qzWRaizZl<
znwSy~4j{dP%3A^mE+pze!BGrOaf||t609POB8){63=9m((AWf191IK$AT{79m1AUJ
zNMWpHNN1>JOlPQNDgi|-qYFc9L@jd-LkhDvLkVLIa|(+%Lo=f|LoG`QQw>WBD>&*|
z!O_W@!j{6`%T&vTi254F1*|n}DIAgv3mF+1YM8)0P6)41q=q4dOPryWy@aWTof4Bu
z*cPzYFqW`oaV%u2Wd!q>O4zbE7c$i{m9Q<~0?94lu3=2!UI^lYSS$;fL>Nlgz@j`0
znbMfRGF&wfc_z3@rWCMBkP1*f@hf6xU|@I&BC2@wK}iA<+VL+zIi&~`=tZD>QUpqS
zMWEm<0)=lCw|+@#aS23i5lHDvQ1a5`xy4$Pn3tY<i@CU@sEC_^fuV>8MDT*r9$R8T
zL26z~5g$l|A0z@vq=sfif?!@zYH>zlLFz4*;>@a4O^#cfDVas7$tC$kl|>>TJ)$60
zjG0kf;5-e^LPbI#31J2XhA4ipD3Y^4DH@!ti^M@1xezG|WO=a?D8(?cFtISQF!C^p
zF!C@;F!3=8Fmr%djC@d7q{6_!pvjEn0M_DyoXnDBP>_L40AWy&u)|7{1>lk-i*W%{
z3Bv+VLStOWw2%>$<Z75D8ERQ-7{N4a4ND4R2@66^4U-5%Gh-{0BttD*4J%xZF@<p<
z6OtTz4GX5O6b7(9jv7{|9Fi`m9A^z%4SNko4QColFoPzOA2`-DS;4V!i>)XzFC{<s
z7I#5vQEq7oIF}YFf<lcsGp|IG3*r4DRZv*-6=#&DrRAii#AoKEq*g@n!ug;i6vYqc
zLW-ALTp;22lA^@SyjwgFF({7{B9mHNa*G=#ky=~=i8pWzf+LR?9O&^N#l@*5>Yx%^
z02K0!B8(i29E>s$$W){Yax+JKd~RZ9UVMD|teGJ2(!NLv6pMnlSo2DA3o5~Wh9pc`
zQ1Ss4Bj9+x#h#Y}DLz@j#V6djyu{qp_;{pH&;;qz0ukCELI-3mYe7+F9<s|3&gZbn
g%}*)KNwouIm12;~IT$&Z__#SF#CSLuxfn$l0qtwa*#H0l

literal 0
HcmV?d00001

diff --git a/knn.py b/knn.py
index e499927..b39e134 100644
--- a/knn.py
+++ b/knn.py
@@ -1,4 +1,6 @@
 import numpy as np
+from read_cifar import read_cifar, split_dataset
+import matplotlib.pyplot as plt
 
 def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray):
     sum_squares_1 = np.sum(matrix_a**2, axis = 1, keepdims = True)
@@ -10,11 +12,43 @@ def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray):
     return dists
 
 def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k:int):
-    return 0
+    """Return, for each row of dists, the majority label among its k nearest training samples."""
+    # dists[i, j] is the distance between sample i to classify and training sample j.
+    labels_predicts = np.zeros(np.size(dists, 0))
+    for i in range(np.size(labels_predicts, 0)):
+        # Indices of the k smallest distances in row i, i.e. the k nearest neighbours.
+        k_neighbors_index = np.argsort(dists[i, :])[:k]
+        # Distinct classes among those neighbours and how many times each one occurs.
+        classes, count = np.unique(labels_train[k_neighbors_index], return_counts=True)
+        # argmax(count) indexes `classes` (not the neighbour list); ties go to the smallest class.
+        labels_predicts[i] = classes[np.argmax(count)]
+    return labels_predicts
+
+def evaluate_knn(data_train:np.ndarray, labels_train: np.ndarray, data_test:np.ndarray, labels_test:np.ndarray, k:int):
+    """Return the fraction of k-NN predictions on data_test that match labels_test."""
+    # Distances are computed from the test samples (first argument) to the training samples.
+    test_to_train = distance_matrix(data_test, data_train)
+    predicted = knn_predict(test_to_train, labels_train, k)
+    n_test = np.size(predicted, 0)
+    # Predictions are stored as floats, so equality is checked with a small tolerance.
+    tolerance = 10**(-7)
+    hits = sum(1 for i in range(n_test) if abs(predicted[i]-labels_test[i])<tolerance)
+    return hits / n_test
+
+def plot_knn(data_train:np.ndarray, labels_train: np.ndarray, data_test:np.ndarray, labels_test:np.ndarray, n: int):
+    """Plot the k-NN accuracy on the test set for every k from 1 to n, then show the figure."""
+    k_values = range(1, n+1)
+    # One evaluate_knn call per k; entry j of accuracy_vector belongs to k = j + 1.
+    accuracy_vector = np.array([evaluate_knn(data_train, labels_train, data_test, labels_test, k) for k in k_values])
+    plt.plot(k_values, accuracy_vector)
+    plt.show()
+
+
 
 
 
 if __name__ == "__main__":
-    A = np.ones((3,3))
-    B = np.ones((3,3))*2
-    dist = distance_matrix(A, B)
\ No newline at end of file
+    data, labels = read_cifar()
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)
+    k = 5 #Number of neighbours
+    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)  # NOTE(review): assigned but never displayed
\ No newline at end of file
diff --git a/read_cifar.py b/read_cifar.py
index 5d1e780..df7393f 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -31,9 +31,9 @@ def read_cifar():
         dict = pickle.load(fo, encoding='bytes')
     data.append(dict[b'data'])
     labels.append(dict[b'labels'])
-    data = np.array(data, np.float32)
+    data = np.array(data, np.float16)
     labels = np.array(labels, np.int64)
-    return np.reshape(data, (np.size(data, 0)*np.size(data, 1), np.size(data, 2))), np.reshape(labels, (np.size(labels, 0)*np.size(labels, 1), 1))
+    return np.reshape(data, (np.size(data, 0)*np.size(data, 1), np.size(data, 2))), np.reshape(labels, (np.size(labels, 0)*np.size(labels, 1)))
 
 
 def split_dataset(data: np.ndarray, labels: np.ndarray, split: float):
@@ -50,5 +50,5 @@ def split_dataset(data: np.ndarray, labels: np.ndarray, split: float):
 
 if __name__ == "__main__":
     data, labels = read_cifar()
-    a, b, c, d = split_dataset(data, labels, 0.8)
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)  # NOTE(review): 0.8 is presumably the training fraction — confirm in split_dataset
     print(1)
-- 
GitLab