Skip to content
Snippets Groups Projects
Commit 14e42b7b authored by Saidi Aya's avatar Saidi Aya
Browse files

Update read_cifar.py

parent 7278d8dc
No related branches found
No related tags found
No related merge requests found
......@@ -26,9 +26,9 @@ def read_cifar_batch (batch_path):
data_dict = load_pickle(bp)
data = data_dict['data']
labels = data_dict['labels']
data = data.reshape(10000,3072)
data = data.reshape(len(data),len(data[0]))
data = data.astype('f') #data must be np.float32 array.
labels = np.array(labels, dtype='i') #labels must be np.int64 array.
labels = np.array(labels, dtype='int64') #labels must be np.int64 array.
return data, labels
def read_cifar(directory_path):
......@@ -48,7 +48,14 @@ def read_cifar(directory_path):
def split_dataset(data, labels, split):
    """Shuffle a dataset and split it into a training set and a test set.

    Args:
        data: Array-like whose first dimension indexes the samples.
        labels: Array-like with the same size as ``data`` in the first
            dimension; ``labels[i]`` is the label of ``data[i]``.
        split: Float between 0 and 1 giving the fraction of the samples
            that goes to the training set; the rest goes to the test set.

    Returns:
        A tuple ``(data_train, labels_train, data_test, labels_test)`` of
        NumPy arrays.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length, or if
            ``split`` is outside ``[0, 1]``.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if len(data) != len(labels):
        raise ValueError(
            "data and labels must have the same size in the first dimension"
        )
    if not 0 <= split <= 1:
        raise ValueError("split must be a float between 0 and 1")

    n_samples = len(data)
    n_train = int(round(split * n_samples))

    # One permutation shared by data and labels keeps every sample paired
    # with its label, and slicing it yields an exact train/test proportion
    # (a per-sample random mask would only match `split` on average).
    order = np.random.permutation(n_samples)
    train_idx = order[:n_train]
    test_idx = order[n_train:]

    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment