diff --git a/read_cifar.py b/read_cifar.py
index 00346bf9cec587fce3e39a3f02324e4562c7f794..66fde3deffa4fe9d1c1c30137ded02b1a61b4e4f 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1,5 +1,6 @@
 import pickle
 import numpy as np
+from sklearn.model_selection import train_test_split
 
 
 def unpickle(file):
@@ -7,29 +8,72 @@ def unpickle(file):
         batch = pickle.load(fo, encoding='bytes')
     return batch
 
+
 def read_cifar_batch(path):
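+    """Load one CIFAR-10 batch file: returns `data` as an (N, 3072) float32
+    array of raw pixel values and `labels` as an (N,) int64 array in [0, 9]."""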
     batch = unpickle(path)
     data = batch[b'data']
     labels = batch[b'labels']
     return np.float32(data), np.int64(labels)
 
+
 def read_cifar(folder_path):
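+    """Return the full CIFAR-10 set (test batch plus data_batch_1..5) as two
+    arrays; the expected shapes are (60000, 3072) for data and (60000,) for labels."""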
+
+    # Start from the test batch, using the folder_path argument
-    data, labels = read_cifar_batch("./data/cifar-10-batches-py/test_batch")
+    data, labels = read_cifar_batch(folder_path + "/test_batch")
+
+    # Concatenate the five training batches; range(1, 6) covers batches 1 to 5
-    for i in range(1,5):
+    for i in range(1, 6):
-        data = np.concatenate((data, read_cifar_batch(folder_path + "/data_batch_" + str(i))[0]))
-        labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1]))
+        # Read each batch once, then grow the data and label arrays
+        batch_data, batch_labels = read_cifar_batch(folder_path + "/data_batch_" + str(i))
+        data = np.concatenate((data, batch_data))
+        labels = np.concatenate((labels, batch_labels))
+
     return data, labels
 
+
 def split_dataset(data, labels, split):
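+    """Shuffle the dataset, then split it into training and test sets;
+    `split` is the fraction of samples assigned to the training set."""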
+
+    # train_test_split shuffles by default; random_state fixes the seed
-    index = int(split * len(data))
-    data_train, data_test = np.split(data, index)
-    labels_train, labels_test = np.split(labels, index)
+    data_train, data_test, labels_train, labels_test = train_test_split(
+        data, labels, train_size=split, random_state=1)
     return data_train, labels_train, data_test, labels_test
     
 
 
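+# A minimal pure-NumPy alternative to train_test_split, left here as a sketch:
+# it assumes `split` is the training fraction and reproduces the shuffle-then-cut
+# idea behind the old np.split code. The name split_dataset_numpy is illustrative.
+def split_dataset_numpy(data, labels, split):
+    rng = np.random.default_rng(1)
+    perm = rng.permutation(len(data))  # random row order, fixed seed
+    index = int(split * len(data))     # cut point between train and test
+    data_train, data_test = data[perm[:index]], data[perm[index:]]
+    labels_train, labels_test = labels[perm[:index]], labels[perm[index:]]
+    return data_train, labels_train, data_test, labels_test
+
+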
 if __name__ == "__main__":
+
+    # Load the full CIFAR-10 dataset (test batch plus the five training batches)
     data, labels = read_cifar("./data/cifar-10-batches-py")
     print(data)
     print(labels)
+
+    # Split the data into training and test sets
+    split = 0.21
+    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
+    print(data_train)
+    print(labels_train)
+    print(data_test)
+    print(labels_test)
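+
+    # Sanity check (illustrative): with split = 0.21 and 60000 samples overall,
+    # the training set should hold about 12600 rows
+    print(data_train.shape, data_test.shape)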