From 0e07e0d31258bf48c9b4e9d0abd4e56ffc05e821 Mon Sep 17 00:00:00 2001
From: BaptisteBrd <75663738+BaptisteBrd@users.noreply.github.com>
Date: Thu, 9 Nov 2023 00:50:09 +0100
Subject: [PATCH] fonction read cifar

---
 read_cifar.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/read_cifar.py b/read_cifar.py
index abba3fa..2aeb021 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -7,8 +7,29 @@ def read_cifar_batch(file):
         labels = np.array(dict[b'labels']).astype('int64')
     return data, labels
 
-vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
+#vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
 
 #print(vect1)
 
-#def read_cifar
\ No newline at end of file
+def read_cifar(directory):
+    all_data = []
+    all_labels = []
+
+    for i in range(1,6):
+        data_v, labels_v = read_cifar_batch(f'{directory}/data_batch_{i}')
+        all_data.append(data_v)
+        all_labels.append(labels_v)
+
+    data_v, labels_v = read_cifar_batch(f'{directory}/test_batch')
+    all_data.append(data_v)
+    all_labels.append(labels_v)
+
+    all_data = np.concatenate(all_data, axis = 0)
+    all_labels = np.concatenate(all_labels, axis = 0)
+
+    return(all_data, all_labels)
+
+
+#vect2= read_cifar("data/cifar-10-batches-py")
+#print(vect2)
+
-- 
GitLab