"""
Created on Fri Oct 20 16:04:49 2023

@author: oscar

Helpers for loading pickled CIFAR batch files into NumPy arrays.
"""

import os
import pickle

import numpy as np


def read_cifar_batch(batch_path):
    """Load one pickled CIFAR batch file.

    The file is unpickled with ``encoding='bytes'``, so the resulting
    dictionary is looked up with the byte-string keys ``b'data'`` and
    ``b'labels'``.

    Parameters
    ----------
    batch_path : str or os.PathLike
        Path of the pickled batch file.

    Returns
    -------
    data : numpy.ndarray
        ``np.array`` of the ``b'data'`` entry.  For the standard CIFAR-10
        python batches this is presumably one flattened image per row --
        TODO confirm against the dataset description.
    labels : numpy.ndarray
        ``np.array`` of the ``b'labels'`` entry.

    Raises
    ------
    KeyError
        If the unpickled batch lacks ``b'data'`` or ``b'labels'``.
    OSError
        If the file cannot be opened.
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only call this on batch files obtained from a trusted source.
    with open(batch_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')

    # Fail fast: with dict.get a missing key would silently become
    # np.array(None), a 0-d object array that only blows up later (e.g. in
    # np.concatenate) with an unrelated error message.
    missing = [key for key in (b'data', b'labels') if key not in batch]
    if missing:
        raise KeyError(
            f"batch file {batch_path!r} is missing key(s): {missing}"
        )

    data = np.array(batch[b'data'])
    labels = np.array(batch[b'labels'])
    return data, labels


def read_cifar(batch_dir, num_train_batches=5):
    """Load every batch found in *batch_dir* and stack them together.

    The files ``data_batch_1`` ... ``data_batch_<num_train_batches>`` are
    read in order, followed by ``test_batch``; their contents are
    concatenated along axis 0 in that same order.

    Parameters
    ----------
    batch_dir : str or os.PathLike
        Directory containing the batch files.
    num_train_batches : int, optional
        Number of ``data_batch_<i>`` files to read (default 5).

    Returns
    -------
    data : numpy.ndarray
        Data of all batches concatenated along axis 0.
    labels : numpy.ndarray
        Labels of all batches concatenated along axis 0, aligned with
        ``data``.
    """
    batch_names = [f'data_batch_{i}' for i in range(1, num_train_batches + 1)]
    # The test batch goes last so its samples end up at the tail of the
    # concatenated arrays.
    batch_names.append('test_batch')

    data_batches = []
    label_batches = []
    for name in batch_names:
        data, labels = read_cifar_batch(os.path.join(batch_dir, name))
        data_batches.append(data)
        label_batches.append(labels)

    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


if __name__ == "__main__":
    # Smoke test: load a single batch from the default download location.
    batch_file = "./data/cifar-10-python/data_batch_1"
    read_cifar_batch(batch_file)