import os
import pickle

import numpy as np


def read_cifar_batch(batch):
    """Load one pickled CIFAR batch file.

    Args:
        batch: Path to a batch file written with ``pickle`` whose dict keys
            are bytes (``b'data'``, ``b'labels'``, ``b'batch_label'``).

    Returns:
        A ``(data, labels)`` tuple: the ``b'data'`` and ``b'labels'`` entries
        of the unpickled dict, returned unchanged.

    Side effect:
        Prints the batch's ``b'batch_label'`` entry to stdout.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only call this on
    # trusted dataset files, never on user-supplied input.
    with open(batch, 'rb') as fo:
        # encoding='bytes' keeps the Python-2-era dict keys as bytes.
        batch_dict = pickle.load(fo, encoding='bytes')
    data = batch_dict[b'data']
    labels = batch_dict[b'labels']
    print(batch_dict[b'batch_label'])
    return data, labels


def read_cifar(path):
    """Load and stack every CIFAR batch file found in a directory.

    Entries named ``batches.meta`` or ``readme.html``, and anything that is
    not a regular file, are skipped. Every other file is read with
    ``read_cifar_batch``. Files are visited in sorted name order so the
    resulting sample order is deterministic.

    Args:
        path: Directory containing the batch files.

    Returns:
        ``(data, labels)``. ``data`` is a float32 array of shape
        ``(n_samples, 3072)`` and ``labels`` is a 1-D int64 array of length
        ``n_samples``.

    Raises:
        ValueError: If the directory contains no batch files.
    """
    skip = {'batches.meta', 'readme.html'}
    data, labels = [], []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for name in sorted(os.listdir(path)):
        full_path = os.path.join(path, name)
        if name in skip or not os.path.isfile(full_path):
            continue
        data_batch, labels_batch = read_cifar_batch(full_path)
        data.append(np.asarray(data_batch, dtype=np.float32).reshape((-1, 3072)))
        labels.append(np.asarray(labels_batch, dtype=np.int64).reshape(-1))
    if not data:
        raise ValueError(f"no CIFAR batch files found in {path!r}")
    # Concatenate instead of reshaping to a fixed 60000 rows so any number
    # of batches (including batches of different sizes) is supported.
    return np.concatenate(data), np.concatenate(labels)


def split_dataset(data, labels, split, *, random_state=None):
    """Shuffle and split a dataset into train and test parts.

    Args:
        data: Sample array (first axis indexes samples).
        labels: Label array aligned with ``data``.
        split: Fraction of samples to put in the training set, in (0, 1).
        random_state: Optional seed / RandomState forwarded to
            ``train_test_split`` for a reproducible shuffle. The default,
            ``None``, keeps the previous non-deterministic behaviour.

    Returns:
        ``(data_train, data_test, labels_train, labels_test)``.
    """
    # Imported lazily so the loaders above are usable without scikit-learn.
    from sklearn.model_selection import train_test_split

    # 1 - split carries float noise (1 - 0.7 == 0.30000000000000004), and
    # sklearn takes ceil(test_size * n) test rows, which would steal a row
    # from the training set. Rounding removes that noise.
    test_size = round(1.0 - split, 10)
    data_train, data_test, labels_train, labels_test = train_test_split(
        data, labels, test_size=test_size, shuffle=True, random_state=random_state
    )
    return data_train, data_test, labels_train, labels_test