import numpy as np
import pickle
import os
from typing import Tuple


# Question 2: read a single CIFAR batch file
def read_cifar_batch(batch_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read one CIFAR-10 batch file.

    Parameters
    ----------
    batch_path : str
        Path to a pickled CIFAR-10 batch (e.g. ``data_batch_1`` or ``test_batch``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` where ``data`` is a float32 array of shape
        ``(n_samples, 3072)`` and ``labels`` is an int64 array of shape
        ``(n_samples,)``.
    """
    # CIFAR-10 batches were pickled under Python 2; 'latin1' decodes the
    # byte-string dict keys correctly under Python 3.
    with open(batch_path, 'rb') as f:
        batch = pickle.load(f, encoding='latin1')

    data = batch['data'].astype(np.float32)
    labels = np.array(batch['labels'], dtype=np.int64)

    return data, labels


# Question 3: read every CIFAR batch (train + test) into one pair of arrays
def read_cifar(cifar_dir: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load all five training batches and the test batch, concatenated.

    Parameters
    ----------
    cifar_dir : str
        Directory containing ``data_batch_1`` .. ``data_batch_5`` and
        ``test_batch`` (e.g. ``data/cifar-10-batches-py``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` with every sample from all six batches, in
        batch order (train batches 1-5 first, then the test batch).
    """
    data_batches = []
    label_batches = []

    # Read the five training batches (file names are data_batch_1 .. data_batch_5).
    for i in range(1, 6):
        batch_path = os.path.join(cifar_dir, f'data_batch_{i}')
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)

    # Read the test batch as well; questions 3/4 want the whole dataset in
    # one array so it can be re-split freely afterwards.
    test_path = os.path.join(cifar_dir, 'test_batch')
    test_data, test_labels = read_cifar_batch(test_path)
    data_batches.append(test_data)
    label_batches.append(test_labels)

    # Concatenate everything along the sample axis.
    all_data = np.concatenate(data_batches)
    all_labels = np.concatenate(label_batches)

    return all_data, all_labels


# Question 4: shuffle and split the dataset into train / test parts
def split_dataset(data: np.ndarray, labels: np.ndarray, split: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Randomly shuffle ``(data, labels)`` together and split them.

    Parameters
    ----------
    data : np.ndarray
        Sample array, first axis indexes samples.
    labels : np.ndarray
        Label array, aligned with ``data`` along the first axis.
    split : float
        Fraction (in ``[0, 1]``) of samples assigned to the training part.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside ``[0, 1]``.
    """
    # Guard against a silent empty/inverted split.
    if not 0.0 <= split <= 1.0:
        raise ValueError(f"split must be in [0, 1], got {split}")

    # Shuffle data and labels with the SAME permutation so each sample
    # stays paired with its label.
    permutation = np.random.permutation(len(data))
    data = data[permutation]
    labels = labels[permutation]

    # Index at which the (shuffled) dataset is cut between train and test.
    split_idx = int(len(data) * split)

    data_train = data[:split_idx]
    labels_train = labels[:split_idx]
    data_test = data[split_idx:]
    labels_test = labels[split_idx:]

    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    cifar_dir = "data/cifar-10-batches-py"

    # Load the full CIFAR-10 dataset.
    all_data, all_labels = read_cifar(cifar_dir)
    print(f"data shape: {all_data.shape}")
    print(f"labels shape: {all_labels.shape}")

    # Split it 90% train / 10% test.
    data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9)
    print(f"Training shape: {data_train.shape}")
    print(f"Testing shape: {data_test.shape}")