import os
import pickle

import numpy as np


def unpickle(file):
    """Load and return the object stored in a pickled CIFAR batch file.

    ``encoding='bytes'`` makes Python-2-era ``str`` objects in the pickle
    decode as ``bytes``, which is why batch keys are accessed as ``b'data'``
    and ``b'labels'`` elsewhere in this module.

    Security note: ``pickle.load`` can execute arbitrary code while loading.
    Only call this on trusted files (e.g. the official dataset download).
    """
    with open(file, "rb") as fo:
        return pickle.load(fo, encoding="bytes")


def read_cifar_batch(path):
    """Read a single pickled CIFAR batch.

    Args:
        path: Path to one batch file (e.g. ``data_batch_1`` or ``test_batch``).

    Returns:
        A ``(data, labels)`` tuple: the batch's ``b'data'`` entry converted to
        a float32 array and its ``b'labels'`` entry converted to an int64 array.
    """
    batch = unpickle(path)
    data = np.asarray(batch[b"data"], dtype=np.float32)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return data, labels


def read_cifar(folder_path, num_train_batches=5):
    """Read the test batch and all training batches found in ``folder_path``.

    Args:
        folder_path: Directory containing ``test_batch`` and
            ``data_batch_1`` .. ``data_batch_<num_train_batches>``.
        num_train_batches: Number of ``data_batch_*`` files to read
            (defaults to 5).

    Returns:
        A ``(data, labels)`` tuple with all batches concatenated along the
        first axis, in the order: ``test_batch``, then ``data_batch_1``,
        ``data_batch_2``, ...
    """
    batch_names = ["test_batch"] + [
        "data_batch_" + str(i) for i in range(1, num_train_batches + 1)
    ]
    # Read each file once, then concatenate once (avoids re-reading files and
    # the quadratic cost of growing an array inside a loop).
    batches = [read_cifar_batch(os.path.join(folder_path, name)) for name in batch_names]
    data_parts, label_parts = zip(*batches)
    return np.concatenate(data_parts), np.concatenate(label_parts)


def split_dataset(data, labels, split):
    """Split samples and labels into a train part and a test part, in order.

    No shuffling is performed: the first ``int(split * len(data))`` samples
    go to the train part and the remaining ones to the test part.

    Args:
        data: Array of samples (split along the first axis).
        labels: Array of labels, same length as ``data``.
        split: Fraction of samples, in ``[0, 1]``, assigned to the train part.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises:
        ValueError: If ``split`` is outside ``[0, 1]`` or if ``data`` and
            ``labels`` have different lengths.
    """
    if not 0 <= split <= 1:
        raise ValueError("split must be in [0, 1], got " + str(split))
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    index = int(split * len(data))
    # Pass a list of split points: np.split(x, k) with a bare int k would
    # instead cut x into k equal-sized sections.
    data_train, data_test = np.split(data, [index])
    labels_train, labels_test = np.split(labels, [index])
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    data, labels = read_cifar("./data/cifar-10-batches-py")
    print(data)
    print(labels)