diff --git a/read_cifar.py b/read_cifar.py
index 8b137891791fe96927ad78e64b0aad7bded08bdc..74e622e4be1af51cfc9fbaf3cec35f74ef8d5c23 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1 +1,85 @@
-
+import numpy as np
+import os
+import pickle
+import random
+
+
+def unpickle(file):
+    # Load a pickled CIFAR-10 batch file and return the raw dictionary
+    with open(file, 'rb') as f:
+        batch = pickle.load(f, encoding='bytes')
+    return batch
+
+
+def read_cifar_batch(batch_path):
+    with open(batch_path, 'rb') as file:
+        # Unpickle the batch
+        batch = pickle.load(file, encoding='bytes')
+
+    # Extract data and labels, scaling pixel values to [0, 1]
+    data = np.array(batch[b'data'], dtype=np.float32) / 255.0
+    labels = np.array(batch[b'labels'], dtype=np.int64)
+
+    return data, labels
+
+
+def read_cifar(batch_dir):
+    data_batches = []
+    label_batches = []
+
+    # Iterate over the batch files (sorted for a deterministic order)
+    for file_name in sorted(os.listdir(batch_dir)):
+        if file_name.startswith("data_batch") or file_name.startswith("test_batch"):
+            batch_path = os.path.join(batch_dir, file_name)
+            data, labels = read_cifar_batch(batch_path)
+            data_batches.append(data)
+            label_batches.append(labels)
+
+    # Combine data and labels from all batches
+    data = np.concatenate(data_batches, axis=0)
+    labels = np.concatenate(label_batches, axis=0)
+
+    return data, labels
+
+
+def split_dataset(data, labels, split):
+    # Check that data and labels have matching first dimensions
+    if data.shape[0] != labels.shape[0]:
+        raise ValueError("data and labels must have the same number of rows!")
+
+    # Determine the size of the train split
+    train_size = round(data.shape[0] * split)
+
+    # Shuffle the indices, then apply the same permutation to data and labels
+    shuffle_index = list(range(data.shape[0]))
+    random.shuffle(shuffle_index)
+    shuffled_data = data[shuffle_index]
+    shuffled_labels = labels[shuffle_index]
+
+    # Slice out the train and test splits
+    data_train = shuffled_data[:train_size]
+    labels_train = shuffled_labels[:train_size]
+    data_test = shuffled_data[train_size:]
+    labels_test = shuffled_labels[train_size:]
+
+    return data_train, labels_train, data_test, labels_test
+
+
+if __name__ == "__main__":
+    data_folder = 'C:\\Users\\hugol\\Desktop\\Centrale Lyon\\Centrale Lyon 4A\\Informatique\\Machine Learning\\BE1\\cifar-10-batches-py'
+    batch_filename = 'data_batch_1'
+
+    batch_path = os.path.join(data_folder, batch_filename)
+    data, labels = read_cifar_batch(batch_path)
+    print("Data shape:", data.shape)
+    print("Labels shape:", labels.shape)
+
+    data_all, labels_all = read_cifar(data_folder)
+    print("Data shape:", data_all.shape)
+    print("Labels shape:", labels_all.shape)
+
+    data_train, labels_train, data_test, labels_test = split_dataset(data_all, labels_all, 0.9)
+    print("Data train shape:", data_train.shape)
+    print("Labels train shape:", labels_train.shape)
+    print("Data test shape:", data_test.shape)
+    print("Labels test shape:", labels_test.shape)