From 80d13b1745cb6cace6e78647f57d7f5a917ddea8 Mon Sep 17 00:00:00 2001
From: Elkhadri Doha <doha.elkhadri@etu.ec-lyon.fr>
Date: Fri, 10 Nov 2023 06:07:30 +0000
Subject: [PATCH] Upload New File

---
 read_cifar.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 read_cifar.py

diff --git a/read_cifar.py b/read_cifar.py
new file mode 100644
index 0000000..749f715
--- /dev/null
+++ b/read_cifar.py
@@ -0,0 +1,62 @@
+import numpy as np
+import pickle
+import os
+
def read_cifar_batch(batch_path):
    """Load one pickled CIFAR batch file.

    Parameters
    ----------
    batch_path : str
        Path to a single CIFAR batch file (pickled dict with the keys
        ``b'data'`` and ``b'labels'``).

    Returns
    -------
    tuple of np.ndarray
        ``(data, labels)`` where ``data`` is ``float32`` and ``labels``
        is ``int64``.
    """
    # NOTE(review): pickle.load on untrusted files is unsafe; CIFAR
    # archives are assumed to come from a trusted source.
    with open(batch_path, 'rb') as fh:
        raw = pickle.load(fh, encoding='bytes')

    # Coerce straight to the dtypes the rest of the pipeline expects.
    data = np.asarray(raw[b'data'], dtype=np.float32)
    labels = np.asarray(raw[b'labels'], dtype=np.int64)
    return data, labels
+
def read_cifar(directory_path):
    """Load and concatenate every CIFAR batch found in a directory.

    Files named ``data_batch*`` plus the single ``test_batch`` file are
    loaded; everything else in the directory is ignored.

    Parameters
    ----------
    directory_path : str
        Directory containing the pickled CIFAR batch files.

    Returns
    -------
    tuple of np.ndarray
        ``(data, labels)`` — all batches stacked along axis 0, with
        ``data`` as ``float32`` and ``labels`` as ``int64``.

    Raises
    ------
    FileNotFoundError
        If the directory contains no recognizable batch files.
    """
    data_list = []
    labels_list = []

    # sorted(): os.listdir() order is arbitrary and platform-dependent,
    # which made the row order of the returned arrays nondeterministic.
    for filename in sorted(os.listdir(directory_path)):
        if filename.startswith('data_batch') or filename == 'test_batch':
            batch_path = os.path.join(directory_path, filename)
            with open(batch_path, 'rb') as file:
                batch = pickle.load(file, encoding='bytes')
                data_list.append(batch[b'data'])
                labels_list.extend(batch[b'labels'])

    # Fail with a clear message instead of np.concatenate's opaque
    # "need at least one array" ValueError.
    if not data_list:
        raise FileNotFoundError(
            f"No CIFAR batch files found in {directory_path!r}")

    # Combine all batches and normalize dtypes in one pass.
    data = np.concatenate(data_list, axis=0).astype(np.float32)
    labels = np.array(labels_list, dtype=np.int64)

    return data, labels
+
+
def split_dataset(data, labels, split, seed=None):
    """Randomly shuffle and split a dataset into train and test parts.

    Parameters
    ----------
    data : np.ndarray
        Samples, indexed along axis 0.
    labels : np.ndarray
        Labels aligned with ``data`` along axis 0.
    split : float
        Fraction (between 0 and 1) of samples assigned to the training set.
    seed : int, optional
        If given, seeds a dedicated NumPy Generator so the split is
        reproducible. When ``None`` (the default) the legacy global RNG
        is used, preserving the original behavior.

    Returns
    -------
    tuple of np.ndarray
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside the [0, 1] range.
    """
    if split < 0.0 or split > 1.0:
        raise ValueError("Split value must be between 0 and 1 ")

    # Number of training samples (floor of the requested fraction).
    num_samples = data.shape[0]
    num_train_samples = int(num_samples * split)

    # Shuffle indices rather than the arrays so data/label pairing is kept.
    if seed is None:
        shuffled_indices = np.random.permutation(num_samples)
    else:
        shuffled_indices = np.random.default_rng(seed).permutation(num_samples)

    # First chunk trains, remainder tests.
    train_indices = shuffled_indices[:num_train_samples]
    test_indices = shuffled_indices[num_train_samples:]

    data_train = data[train_indices]
    labels_train = labels[train_indices]
    data_test = data[test_indices]
    labels_test = labels[test_indices]

    return data_train, labels_train, data_test, labels_test
-- 
GitLab