From 781f0b13ff056f308e5fd6044ec43a3fb6b0737b Mon Sep 17 00:00:00 2001 From: oscarchaufour <101994223+oscarchaufour@users.noreply.github.com> Date: Fri, 20 Oct 2023 17:13:53 +0200 Subject: [PATCH] Update read_cifar.py cifar extraction --- read_cifar.py | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/read_cifar.py b/read_cifar.py index a905f14..4fe8169 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -5,8 +5,44 @@ Created on Fri Oct 20 16:04:49 2023 @author: oscar """ -def read_cifar_batch() : - return +import os +import numpy as np +import pickle + +import pickle + +def read_cifar_batch(batch_path): + with open(batch_path, 'rb') as f: + batch = pickle.load(f, encoding='bytes') + + data = np.array(batch.get(b'data')) + labels = np.array(batch.get(b'labels')) + + + return data, labels + +def read_cifar (batch_dir) : + data_batches = [] + label_batches = [] + + for i in range(1,6) : + batch_filename = f'data_batch_{i}' + batch_path = os.path.join(batch_dir, batch_filename) + data, labels = read_cifar_batch(batch_path) + data_batches.append(data) + label_batches.append(labels) + + test_batch_filename = 'test_batch' + test_batch_path = os.path.join(batch_dir, test_batch_filename) + data_test, labels_test = read_cifar_batch(test_batch_path) + data_batches.append(data_test) + label_batches.append(labels_test) + + data = np.concatenate(data_batches, axis=0) + labels = np.concatenate(label_batches, axis=0) + + return data, labels if __name__ == "__main__": - pass + file = "./data/cifar-10-python/data_batch_1" + read_cifar_batch(file) -- GitLab