From 7cb02fb3c09baf55473527c58787c7ea58902ea4 Mon Sep 17 00:00:00 2001
From: lucile <lucile.audard@ecl20.ec-lyon.fr>
Date: Mon, 23 Oct 2023 09:22:43 +0200
Subject: [PATCH] Update read_cifar.py

---
 read_cifar.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/read_cifar.py b/read_cifar.py
index 827588b..3d2709e 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1,6 +1,7 @@
 import pickle
 import numpy as np
 
+
 def unpickle(file):
     with open(file, 'rb') as fo:
         batch = pickle.load(fo, encoding='bytes')
@@ -19,7 +20,12 @@ def read_cifar(folder_path):
         labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1]))
     return data, labels
 
-
+def split_dataset(data, labels, split):
+    index = int(split * len(data))
+    data_train, data_test = np.split(data, index)
+    labels_train, labels_test = np.split(labels, index)
+    return data_train, labels_train, data_test, labels_test
+    
 
 
 if __name__ == "__main__":
-- 
GitLab