diff --git a/read_cifar.py b/read_cifar.py
index 41207b4209c7453761b6f8515c0218a5bc8ba3d4..362ba3d688ac0ffefe9dc0d10dc90cde0f171a1b 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -26,9 +26,9 @@ def read_cifar_batch (batch_path):
         data_dict = load_pickle(bp)
         data = data_dict['data']
         labels = data_dict['labels']
-        data = data.reshape(10000,3072)
+        data = data.reshape(len(data),len(data[0]))
         data = data.astype('f')  #data must be np.float32 array.
-        labels = np.array(labels, dtype='i')  #labels must be np.int64 array.
+        labels = np.array(labels, dtype='int64')  #labels must be np.int64 array.
         return data, labels
 
 def read_cifar(directory_path):
@@ -48,7 +48,14 @@ def read_cifar(directory_path):
 def split_dataset(data, labels, split):
     #This function splits the dataset into a training set and a test set
     #It takes as parameter data and labels, two arrays that have the same size in the first dimension. And a split, a float between 0 and 1 which determines the split factor of the training set with respect to the test set.
-    data_train, labels_train = shuffle(data.sample(frac=split, random_state=25),)
-    data_test = shuffle(data.drop(data_train.index))
-
-    return data_train, data_test, labels_train, labels_test
+    data_train=[]
+    data_test=[]
+    labels_train=[]
+    labels_test=[]
+    for i in range(0,len(data)):
+        mask=np.random.randn(len(data[i])) <= split
+        data_train.append(data[i][mask])
+        labels_train.append(labels[i][mask])
+        data_test.append(data[i][~mask])
+        labels_test.append(labels[i][~mask])
+    return data_train, labels_train, data_test, labels_test