diff --git a/read_cifar.py b/read_cifar.py
index 41207b4209c7453761b6f8515c0218a5bc8ba3d4..362ba3d688ac0ffefe9dc0d10dc90cde0f171a1b 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -26,9 +26,9 @@ def read_cifar_batch (batch_path):
     data_dict = load_pickle(bp)
     data = data_dict['data']
     labels = data_dict['labels']
-    data = data.reshape(10000,3072)
+    data = data.reshape(len(data),len(data[0]))
     data = data.astype('f') #data must be np.float32 array.
-    labels = np.array(labels, dtype='i') #labels must be np.int64 array.
+    labels = np.array(labels, dtype='int64') #labels must be np.int64 array.
     return data, labels
 
 def read_cifar(directory_path):
@@ -48,7 +48,10 @@ def read_cifar(directory_path):
 def split_dataset(data, labels, split):
 #This function splits the dataset into a training set and a test set
 #It takes as parameter data and labels, two arrays that have the same size in the first dimension. And a split, a float between 0 and 1 which determines the split factor of the training set with respect to the test set.
-    data_train, labels_train = shuffle(data.sample(frac=split, random_state=25),)
-    data_test = shuffle(data.drop(data_train.index))
-
-    return data_train, data_test, labels_train, labels_test
+    #One uniform draw per sample: P(mask[i] is True) == split, so ~split of the
+    #samples land in the training set. (rand, not randn: randn is Gaussian and
+    #would ignore the split value.) Mask is over axis 0, i.e. whole samples.
+    mask = np.random.rand(len(data)) <= split
+    data_train = data[mask]
+    labels_train = labels[mask]
+    data_test = data[~mask]
+    labels_test = labels[~mask]
+    return data_train, labels_train, data_test, labels_test