From c281dc0a0adf8dd9873b4aef378f5786cf6fb16e Mon Sep 17 00:00:00 2001 From: Elkhadri Doha <doha.elkhadri@etu.ec-lyon.fr> Date: Fri, 10 Nov 2023 06:34:30 +0000 Subject: [PATCH] Update read_cifar_test.py --- tests/read_cifar_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/read_cifar_test.py b/tests/read_cifar_test.py index 7693798..ffde3e1 100644 --- a/tests/read_cifar_test.py +++ b/tests/read_cifar_test.py @@ -1,8 +1,8 @@ from read_cifar import read_cifar -def test_read_cifar(): - data, labels = read_cifar(r'C:\Users\etulyon1\OneDrive\Desktop\Deep_Learning1\image-classification\data') - assert data.shape == (60000, 3072) - assert labels.shape == (60000,) - assert data.dtype == 'float32' - assert labels.dtype == 'int64' +def read_cifar_batch(BATCH_PATH): + with open(BATCH_PATH, "rb") as f: + d = pickle.load(f, encoding="bytes") + data = d[b"data"].astype(np.float32) + labels = np.array([d[b"labels"]]).astype(np.int64) + return data, labels -- GitLab