Skip to content
Snippets Groups Projects
Commit 4b2d6c8e authored by Saidi Aya's avatar Saidi Aya
Browse files

Update read_cifar.py

parent eec797e0
No related branches found
No related tags found
No related merge requests found
......@@ -47,14 +47,20 @@ def read_cifar(directory_path):
def split_dataset(data,labels,split):
#This function splits the dataset into a training set and a test set
#It takes as parameter data and labels, two arrays that have the same size in the first dimension. And a split, a float between 0 and 1 which determines the split factor of the training set with respect to the test set.
#split -- the split factor
#data -- the whole data (all the batches including the test batch)
#labels -- the labels associated to the data
data_train=[]
data_test=[]
labels_train=[]
labels_test=[]
for i in range(0,len(data)):
mask=np.random.randn(len(data[i])) <= split
data_train.append(data[i][mask])
labels_train.append(labels[i][mask])
data_test.append(data[i][~mask])
labels_test.append(labels[i][~mask])
labels=labels.reshape(data.shape[0],1)
# Stack our Data and labels
con = np.hstack((data, labels))
k=int(split*con.shape[0])
# Shuffle all our Data
np.random.shuffle(con)
# Train
data_train=con[:k,:-1]
labels_train=con[:k,-1]
# Test
data_test=con[k:,:-1]
labels_test=con[k:,-1]
return data_train,labels_train,data_test,labels_test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment