import os
import pickle

import numpy as np


def unpickle(file):
    # CIFAR-10 batches are pickled with bytes keys, hence encoding='bytes'.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def read_cifar_batch(file):
    batch = unpickle(file)
    data = np.array(batch[b'data'], dtype=np.float32)
    labels = np.array(batch[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(batch_dir):
    data_batches = []
    label_batches = []
    # The five training batches are named data_batch_1 ... data_batch_5.
    for i in range(1, 6):
        batch_path = os.path.join(batch_dir, f'data_batch_{i}')
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)
    # The test batch is appended too, so the full dataset can be re-split freely.
    test_batch_path = os.path.join(batch_dir, 'test_batch')
    data_test, labels_test = read_cifar_batch(test_batch_path)
    data_batches.append(data_test)
    label_batches.append(labels_test)
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


def split_dataset(data, labels, split):
    if len(data) != len(labels):
        raise ValueError("data and labels should have the same size in the first dimension")
    if split < 0 or split > 1:
        raise ValueError("split ratio should be between 0 and 1")
    data_size = len(data)
    # Shuffle the indices once, then slice them into train and test partitions.
    # Slicing at exactly train_set_size avoids the off-by-one of looping to
    # train_set_size + 1, which would leak one extra sample into the train set.
    shuffled_indexes = np.random.permutation(data_size)
    train_set_size = int(data_size * split)
    train_indexes = shuffled_indexes[:train_set_size]
    test_indexes = shuffled_indexes[train_set_size:]
    # Fancy indexing keeps all four outputs as NumPy arrays, including labels.
    data_train = data[train_indexes]
    labels_train = labels[train_indexes]
    data_test = data[test_indexes]
    labels_test = labels[test_indexes]
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    data, labels = read_cifar('data/cifar-10-batches-py/')
    split = 0.9
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    print(len(data_train), len(data_test),
          len(data_train) + len(data_test),
          len(labels_train) + len(labels_test))
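
    # Hedged illustration (an assumption, not part of the original script):
    # each CIFAR-10 row stores 3072 values in channel-major order (1024 red,
    # then 1024 green, then 1024 blue), so a single 32x32 RGB image can be
    # recovered by reshaping to (3, 32, 32) and moving the channel axis last.
    # The variable name `first_image` is introduced only for this demo.
    first_image = data_train[0].reshape(3, 32, 32).transpose(1, 2, 0)
    print(first_image.shape)  # (32, 32, 3); float values still in [0, 255]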