"""Helpers for loading pickled CIFAR batch files and splitting them into train/test sets."""

import os
import pickle

import numpy as np

try:
    # No function in this module calls train_test_split; the import is kept so
    # that code importing the name from here keeps working, but a missing
    # scikit-learn installation must not make the loader itself unusable.
    from sklearn.model_selection import train_test_split  # noqa: F401
except ImportError:
    train_test_split = None


def unpickle(file):
    """Load and return the object pickled in ``file``.

    The file is opened in binary mode and unpickled with ``encoding='bytes'``,
    so strings written by Python 2 come back as ``bytes`` (which is why the
    batch dictionaries are indexed with ``b'...'`` keys).

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only call this
    on batch files obtained from a trusted source.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


def read_cifar_batch(file):
    """Read one pickled batch file and return its samples and labels.

    Args:
        file: Path of a batch file whose pickled dictionary has the keys
            ``b'data'`` and ``b'labels'``.

    Returns:
        A ``(data, labels)`` tuple: ``data`` converted to ``float32`` and
        ``labels`` as a 1-D ``int64`` array.
    """
    batch = unpickle(file)
    # np.asarray also accepts a plain list payload, not only an ndarray.
    data = np.asarray(batch[b'data'], dtype=np.float32)
    labels = np.array(batch[b'labels'], dtype=np.int64)
    labels = labels.reshape(labels.shape[0])
    return data, labels


def read_cifar(path):
    """Read every batch file under ``path`` and stack them into two arrays.

    The files ``data_batch_1`` ... ``data_batch_5`` and ``test_batch`` are read
    in that order (presumably the CIFAR-10 "python version" layout).

    Args:
        path: Directory containing the six batch files.

    Returns:
        A ``(DATA, LABELS)`` tuple holding the batches concatenated along the
        first axis, as ``float32`` and ``int64`` arrays respectively.
    """
    print('Reading data from disk')
    batch_names = ["data_batch_" + str(i) for i in range(1, 6)] + ['test_batch']

    # Collect the pieces and concatenate once; concatenating inside the loop
    # would re-copy everything read so far on every iteration.
    data_parts = []
    label_parts = []
    for name in batch_names:
        data, labels = read_cifar_batch(os.path.join(path, name))
        data_parts.append(data)
        label_parts.append(labels)

    # Every part is already float32 / int64, so no dtype override is needed.
    DATA = np.concatenate(data_parts, axis=0)
    LABELS = np.concatenate(label_parts, axis=-1)
    return DATA, LABELS


def split_dataset(data, labels, split, seed=None):
    """Randomly partition samples and labels into a train and a test subset.

    Args:
        data: Array with one sample per row (first axis).
        labels: Array with one label per sample, same length as ``data``.
        split: Fraction of the samples assigned to the train subset; must lie
            in ``[0, 1]``. The train subset gets ``int(split * n)`` samples.
        seed: Optional seed for a reproducible shuffle. When ``None`` (the
            default) the global ``np.random`` state is used, as before.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)`` with the data
        as ``float32`` and the labels as ``int64``.

    Raises:
        ValueError: If ``split`` is outside ``[0, 1]`` or if ``data`` and
            ``labels`` do not contain the same number of samples.
    """
    print(f"Splitting data into train/test with split={split}")
    if not 0 <= split <= 1:
        raise ValueError(f"split must be in [0, 1], got {split!r}")

    n = data.shape[0]
    if len(labels) != n:
        raise ValueError(
            f"data and labels disagree on sample count: {n} vs {len(labels)}"
        )

    if seed is None:
        # Keep the legacy global RNG so callers relying on np.random.seed()
        # continue to get the same permutations.
        indices = np.random.permutation(n)
    else:
        indices = np.random.default_rng(seed).permutation(n)

    cut = int(split * n)
    train_idx, test_idx = indices[:cut], indices[cut:]

    # Fancy indexing already returns fresh arrays; copy=False avoids a second
    # copy when the dtype already matches.
    data_train = data[train_idx].astype(np.float32, copy=False)
    data_test = data[test_idx].astype(np.float32, copy=False)
    labels_train = labels[train_idx].astype(np.int64, copy=False)
    labels_test = labels[test_idx].astype(np.int64, copy=False)
    return data_train, labels_train, data_test, labels_test