"""Utilities for loading and splitting the CIFAR-10 dataset.

NOTE(review): reconstructed from a unified diff; the middle of
``read_cifar_batch`` was not fully visible in the patch, so its
key-extraction lines follow the standard CIFAR-10 pickle layout
(``b'data'`` / ``b'labels'``) — confirm against the original file.
"""
import numpy as np
import pickle


def read_cifar_batch(path):
    """Load one pickled CIFAR-10 batch file.

    Parameters
    ----------
    path : str
        Path to a CIFAR-10 batch (e.g. ``data_batch_1`` or ``test_batch``).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(data, labels)``: ``data`` of shape ``(n, 3072)`` and
        ``labels`` of shape ``(n,)``.
    """
    # Renamed the original local `dict` -> `batch`: shadowing the builtin
    # `dict` is a classic pitfall.  CIFAR batches are pickled dicts whose
    # keys are byte strings (hence encoding='bytes').
    with open(path, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    # Labels come back as a plain list in the original pickles; make them
    # an ndarray so callers can index/concatenate uniformly.
    return batch[b'data'], np.asarray(batch[b'labels'])


def read_cifar(directory):
    """Load the whole CIFAR-10 dataset as one pair of arrays.

    Concatenates the five training batches AND the test batch — the split
    into train/test is deliberately redone later by ``split_dataset``.

    Parameters
    ----------
    directory : str
        Folder containing ``data_batch_1..5`` and ``test_batch``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` with 60000 rows when all six batches are present.
    """
    data = []
    labels = []
    for i in range(1, 6):
        batch_data, batch_labels = read_cifar_batch(f'{directory}/data_batch_{i}')
        data.append(batch_data)
        labels.append(batch_labels)
    batch_data, batch_labels = read_cifar_batch(f'{directory}/test_batch')
    data.append(batch_data)
    labels.append(batch_labels)
    return np.concatenate(data, axis=0), np.concatenate(labels, axis=0)


def split_dataset(data, labels, split):
    """Randomly split ``(data, labels)`` into train and test parts.

    Parameters
    ----------
    data : np.ndarray
        Samples, first axis indexes examples.
    labels : np.ndarray
        Labels aligned with ``data`` along the first axis.
    split : float
        Fraction (in ``[0, 1]``) of examples assigned to the training set;
        ``int(n * split)`` rows go to train, the rest to test.

    Returns
    -------
    tuple
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside ``[0, 1]`` (previously this silently
        produced a nonsense split).
    """
    if not 0.0 <= split <= 1.0:
        raise ValueError(f"split must be in [0, 1], got {split}")
    n = data.shape[0]
    train_size = int(n * split)
    # One shared permutation keeps each sample paired with its label.
    # (np.random.permutation replaces the original arange+shuffle idiom.)
    order = np.random.permutation(n)
    train_idx = order[:train_size]
    test_idx = order[train_size:]
    return (data[train_idx], labels[train_idx],
            data[test_idx], labels[test_idx])


if __name__ == "__main__":
    pair = read_cifar("data/cifar-10-batches-py")
    vect3 = split_dataset(pair[0], pair[1], 0.6)
    print(vect3)