diff --git a/read_cifar.py b/read_cifar.py
index 827588b3cfb53e055ce95cd9b6f2c7bf0afeaaad..3d2709e6111a9048dd7cf2534722178d8295dd97 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1,6 +1,7 @@
 import pickle
 import numpy as np
 
+
 def unpickle(file):
     with open(file, 'rb') as fo:
         batch = pickle.load(fo, encoding='bytes')
@@ -19,7 +20,12 @@ def read_cifar(folder_path):
         labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1]))
     return data, labels
 
-
+def split_dataset(data, labels, split):
+    index = int(split * len(data))
+    data_train, data_test = np.split(data, index)
+    labels_train, labels_test = np.split(labels, index)
+    return data_train, labels_train, data_test, labels_test
+    
 
 
 if __name__ == "__main__":