diff --git a/read_cifar.py b/read_cifar.py index 827588b3cfb53e055ce95cd9b6f2c7bf0afeaaad..3d2709e6111a9048dd7cf2534722178d8295dd97 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -1,6 +1,7 @@ import pickle import numpy as np + def unpickle(file): with open(file, 'rb') as fo: batch = pickle.load(fo, encoding='bytes') @@ -19,7 +20,12 @@ def read_cifar(folder_path): labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1])) return data, labels - +def split_dataset(data, labels, split): + index = int(split * len(data)) + data_train, data_test = np.split(data, index) + labels_train, labels_test = np.split(labels, index) + return data_train, labels_train, data_test, labels_test + if __name__ == "__main__":