From 14e42b7be5a4fa41fdaa51ead530a4884efcc079 Mon Sep 17 00:00:00 2001 From: Aya SAIDI <aya.saidi@auditeur.ec-lyon.fr> Date: Fri, 21 Oct 2022 17:05:17 +0100 Subject: [PATCH] Update read_cifar.py --- read_cifar.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/read_cifar.py b/read_cifar.py index 41207b4..362ba3d 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -26,9 +26,9 @@ def read_cifar_batch (batch_path): data_dict = load_pickle(bp) data = data_dict['data'] labels = data_dict['labels'] - data = data.reshape(10000,3072) + data = data.reshape(len(data),len(data[0])) data = data.astype('f') #data must be np.float32 array. - labels = np.array(labels, dtype='i') #labels must be np.int64 array. + labels = np.array(labels, dtype='int64') #labels must be np.int64 array. return data, labels def read_cifar(directory_path): @@ -48,7 +48,14 @@ def read_cifar(directory_path): def split_dataset(data, labels, split): #This function splits the dataset into a training set and a test set #It takes as parameter data and labels, two arrays that have the same size in the first dimension. And a split, a float between 0 and 1 which determines the split factor of the training set with respect to the test set. - data_train, labels_train = shuffle(data.sample(frac=split, random_state=25),) - data_test = shuffle(data.drop(data_train.index)) - - return data_train, data_test, labels_train, labels_test + data_train=[] + data_test=[] + labels_train=[] + labels_test=[] + for i in range(0,len(data)): + mask=np.random.randn(len(data[i])) <= split + data_train.append(data[i][mask]) + labels_train.append(labels[i][mask]) + data_test.append(data[i][~mask]) + labels_test.append(labels[i][~mask]) + return data_train, labels_train, data_test, labels_test -- GitLab