From c281dc0a0adf8dd9873b4aef378f5786cf6fb16e Mon Sep 17 00:00:00 2001
From: Elkhadri Doha <doha.elkhadri@etu.ec-lyon.fr>
Date: Fri, 10 Nov 2023 06:34:30 +0000
Subject: [PATCH] Update read_cifar_test.py

---
 tests/read_cifar_test.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/read_cifar_test.py b/tests/read_cifar_test.py
index 7693798..ffde3e1 100644
--- a/tests/read_cifar_test.py
+++ b/tests/read_cifar_test.py
@@ -1,8 +1,8 @@
 from read_cifar import read_cifar
 
-def test_read_cifar():
-    data, labels = read_cifar(r'C:\Users\etulyon1\OneDrive\Desktop\Deep_Learning1\image-classification\data')
-    assert data.shape == (60000, 3072)
-    assert labels.shape == (60000,)
-    assert data.dtype == 'float32'
-    assert labels.dtype == 'int64'
+def read_cifar_batch(BATCH_PATH):
+    with open(BATCH_PATH, "rb") as f:
+        d = pickle.load(f, encoding="bytes")
+        data = d[b"data"].astype(np.float32)
+        labels = np.array([d[b"labels"]]).astype(np.int64)
+        return data, labels
-- 
GitLab