From 7cb02fb3c09baf55473527c58787c7ea58902ea4 Mon Sep 17 00:00:00 2001 From: lucile <lucile.audard@ecl20.ec-lyon.fr> Date: Mon, 23 Oct 2023 09:22:43 +0200 Subject: [PATCH] Update read_cifar.py --- read_cifar.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/read_cifar.py b/read_cifar.py index 827588b..3d2709e 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -1,6 +1,7 @@ import pickle import numpy as np + def unpickle(file): with open(file, 'rb') as fo: batch = pickle.load(fo, encoding='bytes') @@ -19,7 +20,12 @@ def read_cifar(folder_path): labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1])) return data, labels - +def split_dataset(data, labels, split): + index = int(split * len(data)) + data_train, data_test = np.split(data, index) + labels_train, labels_test = np.split(labels, index) + return data_train, labels_train, data_test, labels_test + if __name__ == "__main__": -- GitLab