diff --git a/read_cifar.py b/read_cifar.py
index 2aeb021bed6dfd1767498bd1196c82e0eef67a2b..6e0c7cd6d67feb628273c9c8332a1c6029d8e4e4 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pickle
+
 def read_cifar_batch(file):
     with open(file, 'rb') as fo:
         dict = pickle.load(fo, encoding='bytes')
@@ -12,24 +13,49 @@ def read_cifar_batch(file):
 #print(vect1)
 
 def read_cifar(directory):
-    all_data = []
-    all_labels = []
+    data = []
+    labels = []
 
     for i in range(1,6):
         data_v, labels_v = read_cifar_batch(f'{directory}/data_batch_{i}')
-        all_data.append(data_v)
-        all_labels.append(labels_v)
+        data.append(data_v)
+        labels.append(labels_v)
 
     data_v, labels_v = read_cifar_batch(f'{directory}/test_batch')
-    all_data.append(data_v)
-    all_labels.append(labels_v)
+    data.append(data_v)
+    labels.append(labels_v)
+
+    data = np.concatenate(data, axis = 0)
+    labels = np.concatenate(labels, axis = 0)
+
+    return(data, labels)
+
+def split_dataset(data, labels, split):
+
+    data_size = data.shape[0]
+    train_size = int(data_size * split)
+    indices = np.arange(data_size)
+    np.random.shuffle(indices)
+
+    indices_train = indices[:train_size]
+    indices_test = indices[train_size:]
+    data_train = data[indices_train]
+    labels_train = labels[indices_train]
+    data_test = data[indices_test]
+    labels_test = labels[indices_test]
+    
+    return(data_train, labels_train, data_test, labels_test)
 
-    all_data = np.concatenate(all_data, axis = 0)
-    all_labels = np.concatenate(all_labels, axis = 0)
+    
 
-    return(all_data, all_labels)
 
+if __name__ == "__main__":
+    #vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
+    #print(vect1)
+    #vect2= read_cifar("data/cifar-10-batches-py")
+    #print(vect2)
 
-#vect2= read_cifar("data/cifar-10-batches-py")
-#print(vect2)
+    pair = read_cifar("data/cifar-10-batches-py")
 
+    vect3= split_dataset(pair[0], pair[1], 0.6)
+    print(vect3)