From 8047c92f712a910c77b57d9350391fe4773bccad Mon Sep 17 00:00:00 2001 From: lucile <lucile.audard@ecl20.ec-lyon.fr> Date: Sun, 22 Oct 2023 19:59:07 +0200 Subject: [PATCH] Update read_cifar.py --- read_cifar.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/read_cifar.py b/read_cifar.py index 5d0736d..827588b 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -11,9 +11,19 @@ def read_cifar_batch(path): data = batch[b'data'] labels = batch[b'labels'] return np.float32(data), np.int64(labels) - - + +def read_cifar(folder_path): + data, labels = read_cifar_batch("./data/cifar-10-python.tar/cifar-10-batches-py~/cifar-10-batches-py/test_batch") + for i in range(1,5): + data = np.concatenate((data, read_cifar_batch(folder_path + "/data_batch_" + str(i))[0])) + labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1])) + return data, labels + + + + if __name__ == "__main__": - v1, v2 = read_cifar_batch("./data/cifar-10-batches-py/cifar-10-batches-py~/cifar-10-batches-py/data_batch_1") - print(v1) - print(v2) + data, labels = read_cifar("./data/cifar-10-python.tar/cifar-10-batches-py~/cifar-10-batches-py") + print(data) + print(labels) + -- GitLab