Skip to content
Snippets Groups Projects
Commit 7cb02fb3 authored by Audard Lucile's avatar Audard Lucile
Browse files

Update read_cifar.py

parent 8047c92f
No related branches found
No related tags found
No related merge requests found
import pickle
import numpy as np
def unpickle(file):
with open(file, 'rb') as fo:
batch = pickle.load(fo, encoding='bytes')
......@@ -19,7 +20,12 @@ def read_cifar(folder_path):
labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1]))
return data, labels
def split_dataset(data, labels, split):
index = int(split * len(data))
data_train, data_test = np.split(data, index)
labels_train, labels_test = np.split(labels, index)
return data_train, labels_train, data_test, labels_test
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment