diff --git a/read_cifar.py b/read_cifar.py
index 7b24a902e2ef69c1e7c485517612e3c3606eac61..d527d1399941b0f8ce5a641633d018feab1a417f 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -43,18 +43,35 @@ def read_cifar(directory_path):
     data.append( Xt )
     labels.append( Yt )
     return data,labels
 
 def split_dataset(data, labels, split):
-    #This function splits the dataset into a training set and a test set
-    #It takes as parameter data and labels, two arrays that have the same size in the first dimension. And a split, a float between 0 and 1 which determines the split factor of the training set with respect to the test set.
+    # Split the dataset into a training set and a test set.
+    # data -- the whole dataset (all batches, including the test batch)
+    # labels -- the labels associated with the data
+    # split -- a float between 0 and 1 giving the fraction of samples
+    #          assigned to the training set
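+    # Worked example (illustrative numbers, not from this repo): with
+    # 60000 rows and split=0.8, k = 48000, so the first 48000 shuffled
+    # rows become the training set and the remaining 12000 the test set.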
-    data_train=[]
-    data_test=[]
-    labels_train=[]
-    labels_test=[]
-    for i in range(0,len(data)):
-        mask=np.random.randn(len(data[i])) <= split
-        data_train.append(data[i][mask])
-        labels_train.append(labels[i][mask])
-        data_test.append(data[i][~mask])
-        labels_test.append(labels[i][~mask])
-    return data_train, labels_train, data_test, labels_test
+    # Reshape labels to a column vector so it can sit next to the data
+    labels = labels.reshape(data.shape[0], 1)
+    # Stack data and labels side by side so they get shuffled together
+    con = np.hstack((data, labels))
+    # Number of training samples according to the split factor
+    k = int(split * con.shape[0])
+    # Shuffle all the rows in place
+    np.random.shuffle(con)
+    # Training set: the first k shuffled rows
+    data_train = con[:k, :-1]
+    # hstack promotes labels to a common dtype, so cast them back
+    labels_train = con[:k, -1].astype(labels.dtype)
+    # Test set: the remaining rows
+    data_test = con[k:, :-1]
+    labels_test = con[k:, -1].astype(labels.dtype)
+    return data_train, labels_train, data_test, labels_test
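+
+# Usage sketch -- assumes the batches returned by read_cifar have first been
+# concatenated into single 2-D/1-D arrays (e.g. with np.concatenate); the
+# directory name and split value below are only examples:
+# data, labels = read_cifar("cifar-10-batches-py")
+# data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.9)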