From 0e07e0d31258bf48c9b4e9d0abd4e56ffc05e821 Mon Sep 17 00:00:00 2001 From: BaptisteBrd <75663738+BaptisteBrd@users.noreply.github.com> Date: Thu, 9 Nov 2023 00:50:09 +0100 Subject: [PATCH] fonction read cifar --- read_cifar.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/read_cifar.py b/read_cifar.py index abba3fa..2aeb021 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -7,8 +7,29 @@ def read_cifar_batch(file): labels = np.array(dict[b'labels']).astype('int64') return data, labels -vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1") +#vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1") #print(vect1) -#def read_cifar \ No newline at end of file +def read_cifar(directory): + all_data = [] + all_labels = [] + + for i in range(1,6): + data_v, labels_v = read_cifar_batch(f'{directory}/data_batch_{i}') + all_data.append(data_v) + all_labels.append(labels_v) + + data_v, labels_v = read_cifar_batch(f'{directory}/test_batch') + all_data.append(data_v) + all_labels.append(labels_v) + + all_data = np.concatenate(all_data, axis = 0) + all_labels = np.concatenate(all_labels, axis = 0) + + return(all_data, all_labels) + + +#vect2= read_cifar("data/cifar-10-batches-py") +#print(vect2) + -- GitLab