def split_dataset(data, labels, split):
    """Randomly split a dataset into a training set and a test set.

    Parameters
    ----------
    data : array_like
        Samples, shape (n_samples, n_features). The whole data
        (all the batches including the test batch).
    labels : array_like
        Label of each sample, shape (n_samples,) or (n_samples, 1).
    split : float
        Split factor in [0, 1]: fraction of samples that go to the
        training set; the remainder goes to the test set.

    Returns
    -------
    data_train, labels_train, data_test, labels_test : np.ndarray
        Disjoint shuffled partitions of (data, labels); labels are
        returned 1-D with their original dtype preserved.

    Raises
    ------
    ValueError
        If `data` and `labels` do not have the same number of samples.
    """
    data = np.asarray(data)
    labels = np.asarray(labels).reshape(-1)
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    # Shuffle by permuting indices rather than hstacking labels onto data:
    # hstack would upcast labels to the common dtype with data (e.g. int
    # class labels become floats on normalized data) and copy the whole
    # dataset just to shuffle it.
    perm = np.random.permutation(len(data))
    # Number of training samples; everything after index k is the test set.
    k = int(split * len(data))
    train_idx, test_idx = perm[:k], perm[k:]
    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]