"""Helpers for loading pickled CIFAR batch files and splitting the result."""

import os
import pickle

import numpy as np
from sklearn.model_selection import train_test_split

# Files found in the dataset directory that are not pickled data batches.
_NON_BATCH_FILES = frozenset({'batches.meta', 'readme.html'})


def read_cifar_batch(batch):
    """Load a single pickled batch file.

    Args:
        batch: Path of the batch file to read.

    Returns:
        A ``(batch_data, batch_labels)`` tuple holding the values stored under
        the ``b'data'`` and ``b'labels'`` keys of the unpickled dictionary.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the pickled dictionary lacks ``b'data'`` or ``b'labels'``.
    """
    # SECURITY: pickle can execute arbitrary code while loading; only call
    # this on batch files obtained from a trusted source.
    with open(batch, 'rb') as file:
        # encoding='bytes' leaves the dictionary keys as bytes objects, which
        # is why the lookups below use b'...' keys.
        content = pickle.load(file, encoding='bytes')
    return content[b'data'], content[b'labels']


def read_cifar(path):
    """Load every batch file found in a directory into two arrays.

    Every entry of ``path`` except ``batches.meta`` and ``readme.html`` is
    read with :func:`read_cifar_batch`. Entries are processed in sorted name
    order so that the sample order is reproducible (``os.listdir`` gives no
    ordering guarantee). The number of samples is inferred from the files
    instead of being fixed, so a directory holding only some of the batches
    loads as well.

    Args:
        path: Directory containing the batch files.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is a 2-D ``float32`` array with
        one flattened sample per row; ``labels`` is a 1-D ``int64`` array.
        Assumes each batch stores one sample per row of ``b'data'`` and one
        label per sample -- TODO confirm against the dataset in use.

    Raises:
        ValueError: If the directory contains no batch file.
    """
    data_parts, label_parts = [], []
    for name in sorted(os.listdir(path)):
        if name in _NON_BATCH_FILES:
            continue
        data_batch, labels_batch = read_cifar_batch(os.path.join(path, name))
        data_parts.append(np.asarray(data_batch))
        label_parts.append(np.asarray(labels_batch).reshape(-1))

    if not data_parts:
        raise ValueError(f"no batch files found in {path!r}")

    data = np.concatenate(data_parts, axis=0).astype(np.float32)
    # Guarantee the documented 2-D layout (one flattened sample per row).
    data = data.reshape(data.shape[0], -1)
    labels = np.concatenate(label_parts).astype(np.int64)
    return data, labels


def split_dataset(data, labels, split_factor):
    """Randomly split samples and labels into training and test subsets.

    Args:
        data: Array of samples.
        labels: Array of labels aligned with ``data``.
        split_factor: Fraction of the samples to keep for training; the
            remaining ``1 - split_factor`` is passed to scikit-learn as
            ``test_size``.

    Returns:
        A ``(data_train, data_test, labels_train, labels_test)`` tuple.
    """
    data_train, data_test, labels_train, labels_test = train_test_split(
        data, labels, test_size=1 - split_factor, shuffle=True
    )
    return data_train, data_test, labels_train, labels_test