diff --git a/read_cifar.py b/read_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..116997425444f4baf5650ad8abe1ca554adf2e6d
--- /dev/null
+++ b/read_cifar.py
@@ -0,0 +1,85 @@
+import numpy as np
+import pickle
+import os
+
+
def unpickle(file):
    """Deserialize a pickled CIFAR-10 batch file.

    Parameters
    ----------
    file : str or path-like
        Path to a pickled batch file.

    Returns
    -------
    dict
        The unpickled object. CIFAR-10 batches were pickled under
        Python 2, so ``encoding='bytes'`` is required and the keys
        come back as ``bytes`` (e.g. ``b'data'``).
    """
    # NOTE: renamed the local from `dict` to `batch` — the original
    # shadowed the `dict` builtin.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
+
+
def read_cifar_batch(file):
    """Load one CIFAR-10 batch file as numpy arrays.

    Parameters
    ----------
    file : str or path-like
        Path to a single pickled batch file.

    Returns
    -------
    (data, labels) : tuple of np.ndarray
        ``data`` as float32, ``labels`` as int64.
    """
    # Renamed the local from `dict` to `batch` — the original shadowed
    # the `dict` builtin.
    batch = unpickle(file)

    # Keys are bytes (b'data', b'labels') because the batches were
    # pickled under Python 2 and loaded with encoding='bytes'.
    data = np.array(batch[b'data'], dtype=np.float32)
    labels = np.array(batch[b'labels'], dtype=np.int64)

    return data, labels
+
+
def read_cifar(batch_dir):
    """Load the full CIFAR-10 dataset (5 train batches + the test batch).

    All six batches are concatenated into a single pair of arrays;
    splitting back into train/test is left to the caller.

    Parameters
    ----------
    batch_dir : str or path-like
        Directory containing ``data_batch_1`` … ``data_batch_5`` and
        ``test_batch``.

    Returns
    -------
    (data, labels) : tuple of np.ndarray
        Concatenation of every batch, in file order (train batches
        first, test batch last).
    """
    data_batches = []
    label_batches = []

    # The five training batches are named data_batch_1 ... data_batch_5.
    for i in range(1, 6):
        batch_path = os.path.join(batch_dir, f'data_batch_{i}')
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)

    # The held-out batch lives in a separately named file.
    data_test, labels_test = read_cifar_batch(os.path.join(batch_dir, 'test_batch'))
    data_batches.append(data_test)
    label_batches.append(labels_test)

    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    return data, labels
+
+
def split_dataset(data, labels, split):
    """Randomly shuffle and split a dataset into train and test parts.

    Parameters
    ----------
    data : array-like
        Samples, indexed along the first dimension.
    labels : array-like
        Labels, same length as ``data``.
    split : float
        Fraction of the dataset assigned to the train split, in [0, 1].

    Returns
    -------
    (data_train, labels_train, data_test, labels_test) : tuple of np.ndarray

    Raises
    ------
    ValueError
        If the lengths differ or ``split`` is outside [0, 1].
    """
    if len(data) != len(labels):
        raise ValueError("data and labels should have the same size in the first dimension")

    if split < 0 or split > 1:
        raise ValueError("Split ratio should be between 0 and 1")

    data = np.asarray(data)
    labels = np.asarray(labels)

    data_size = len(data)
    shuffled_indexes = np.random.permutation(data_size)
    # BUGFIX: the original iterated range(train_set_size + 1), putting one
    # extra sample in the train split (e.g. 54001/5999 instead of
    # 54000/6000 at split=0.9 on 60000 samples).
    train_set_size = int(data_size * split)

    train_idx = shuffled_indexes[:train_set_size]
    test_idx = shuffled_indexes[train_set_size:]

    # Fancy indexing replaces the manual append loops; labels are now
    # returned as arrays too (the original returned them as Python
    # lists, inconsistently with data).
    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]
+
+
+if __name__ == "__main__":
+    batch = read_cifar('data/cifar-10-batches-py/')
+    data = batch[0]
+    labels = batch[1]
+    split = 0.9
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
+    print(len(data_train), len(data_test), len(data_train)+len(data_test), len(labels_train)+len(labels_test))
+
+
+