import numpy as np
import pickle
import os


def unpickle(file):
    """Load one pickled CIFAR-10 batch file.

    Returns the raw batch dict; keys are bytes (b'data', b'labels', ...)
    because the archives were pickled under Python 2.
    """
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # bytes keys from the py2 pickle
    return batch


def read_cifar_batch(file):
    """Read one CIFAR-10 batch file.

    Returns:
        data: float32 array of shape (batch_size, 3072) — flattened RGB images.
        labels: int64 array of shape (batch_size,).
    """
    batch = unpickle(file)
    data = np.array(batch[b'data'], dtype=np.float32)
    labels = np.array(batch[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(batch_dir):
    """Read the five training batches AND the test batch from *batch_dir*.

    All six batches are concatenated into a single dataset (the caller is
    expected to re-split it, e.g. with split_dataset).

    Returns:
        data: float32 array of shape (60000, 3072).
        labels: int64 array of shape (60000,).
    """
    data_batches = []
    label_batches = []

    # Five training batches: data_batch_1 .. data_batch_5.
    for i in range(1, 6):
        batch_path = os.path.join(batch_dir, f'data_batch_{i}')
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)

    # The held-out batch is appended too — one pooled dataset on purpose.
    data_test, labels_test = read_cifar_batch(os.path.join(batch_dir, 'test_batch'))
    data_batches.append(data_test)
    label_batches.append(labels_test)

    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


def split_dataset(data, labels, split):
    """Randomly split (data, labels) into a train part and a test part.

    Args:
        data: array-like of samples, first dimension is the sample axis.
        labels: array-like of labels, same length as *data*.
        split: fraction of samples assigned to the train split, in [0, 1].

    Returns:
        (data_train, labels_train, data_test, labels_test) as numpy arrays.
        Train gets int(len(data) * split) samples, test gets the rest.

    Raises:
        ValueError: if lengths differ or *split* is outside [0, 1].

    Fixes vs. previous version: off-by-one that put train_set_size + 1
    samples in the train split; labels were returned as Python lists while
    data were arrays; manual append loops replaced by fancy indexing.
    """
    if len(data) != len(labels):
        raise ValueError("data and labels should have the same size in the first dimension")

    if split < 0 or split > 1:
        raise ValueError("Split ratio should be between 0 and 1")

    data = np.asarray(data)
    labels = np.asarray(labels)

    data_size = len(data)
    shuffled_indexes = np.random.permutation(data_size)
    train_set_size = int(data_size * split)

    # Exactly train_set_size samples go to train, the remainder to test.
    train_idx = shuffled_indexes[:train_set_size]
    test_idx = shuffled_indexes[train_set_size:]

    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]


if __name__ == "__main__":
    batch = read_cifar('data/cifar-10-batches-py/')
    data = batch[0]
    labels = batch[1]
    split = 0.9
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    print(len(data_train), len(data_test), len(data_train)+len(data_test), len(labels_train)+len(labels_test))