From 781f0b13ff056f308e5fd6044ec43a3fb6b0737b Mon Sep 17 00:00:00 2001
From: oscarchaufour <101994223+oscarchaufour@users.noreply.github.com>
Date: Fri, 20 Oct 2023 17:13:53 +0200
Subject: [PATCH] Update read_cifar.py

cifar extraction
---
 read_cifar.py | 42 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 39 insertions(+), 3 deletions(-)

diff --git a/read_cifar.py b/read_cifar.py
index a905f14..4fe8169 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -5,8 +5,44 @@ Created on Fri Oct 20 16:04:49 2023
 @author: oscar
 """
 
-def read_cifar_batch() :
-    return 
+import os
+import numpy as np
+import pickle
+
+import pickle
+
+def read_cifar_batch(batch_path):
+    with open(batch_path, 'rb') as f:
+        batch = pickle.load(f, encoding='bytes')
+        
+        data = np.array(batch.get(b'data'))
+        labels = np.array(batch.get(b'labels'))
+
+ 
+    return data, labels
+
+def read_cifar (batch_dir) :
+    data_batches = []
+    label_batches = []
+    
+    for i in range(1,6) :
+        batch_filename = f'data_batch_{i}'
+        batch_path = os.path.join(batch_dir, batch_filename)
+        data, labels = read_cifar_batch(batch_path)
+        data_batches.append(data)
+        label_batches.append(labels)
+        
+        test_batch_filename = 'test_batch'
+        test_batch_path = os.path.join(batch_dir, test_batch_filename)
+        data_test, labels_test = read_cifar_batch(test_batch_path)
+        data_batches.append(data_test)
+        label_batches.append(labels_test)
+        
+        data = np.concatenate(data_batches, axis=0)
+        labels = np.concatenate(label_batches, axis=0)
+
+    return data, labels
 
 if __name__ == "__main__":
-    pass
+    file = "./data/cifar-10-python/data_batch_1"
+    read_cifar_batch(file)
-- 
GitLab