diff --git a/read_cifar.py b/read_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..5b01659ed4c4edbd505fe6825a2d8bb3d3ec9c6b --- /dev/null +++ b/read_cifar.py @@ -0,0 +1,88 @@ +import numpy as np +import os +import pickle + +# Commentaire global expliquant le but du code +'''Here is the code to prepare the CIFAR dataset, create a function to read CIFAR batches, and split the dataset into training and testing sets:''' + +# Fonction read_cifar_batch : +''' +Arguments : +Le chemin d'un seul batch en tant que chaîne de caractères. + +Returns : +Une matrice de données de taille (taille_du_batch , taille_des_données) +Un vecteur d'étiquettes (labels) de taille (taille_du_batch) +''' + +def read_cifar_batch(batch_path): + # Ouvre le fichier du batch et charge les données + with open(batch_path, 'rb') as f: + batch_dict = pickle.load(f, encoding='bytes') + # Convertit les données en float32 + data = batch_dict[b'data'].astype(np.float32) + # Convertit les étiquettes en int64 + labels = np.array(batch_dict[b'labels'], dtype=np.int64) + return data, labels + +# Fonction read_cifar : +''' +*** lit le chemin du répertoire contenant tous les lots (y compris test_batch)*** + +*Arguments : +Le chemin du répertoire contenant les six lots (cinq data_batch et un test_batch) en tant que chaîne de caractères. + +*Returns : +-Une matrice de données de taille (taille_du_lot , taille_des_données). +-Un vecteur d'étiquettes(labels) de taille (taille_du_lot). +''' +def read_cifar(folder): + # Liste des noms de fichiers de batch + batch_files = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5", "test_batch"] + data_list, labels_list = [], [] + + # Boucle sur les fichiers de batch + for batch_file in batch_files: + path = os.path.join(folder, batch_file) + # Appelle read_cifar_batch pour lire chaque batch + data, labels = read_cifar_batch(path) + data_list.append(data) + labels_list.append(labels) + + # Combine les données de tous les batches + data = np.vstack(data_list) + # Combine les étiquettes de tous les batches + labels = np.hstack(labels_list) + + return data, labels + +# Fonction pour diviser les données en ensembles d'entraînement et de test : +''' +*Arguments : +-data et labels, deux tableaux de même taille dans la première dimension. +-split, un nombre flottant compris entre 0 et 1, qui détermine le facteur de répartition de l'ensemble d'entraînement par rapport à l'ensemble de test. + +*Renvoie : +-data_train : les données d'entraînement. +-labels_train : les étiquettes correspondantes. +-data_test : les données de test. +-labels_test : les étiquettes correspondantes. +''' +def split_dataset(data, labels, split): + # Vérifie que le ratio de division est valide + assert 0 < split < 1 + n = len(data) + # Mélange les indices des données + indices = np.random.permutation(n) + split_index = int(split * n) + # Sépare les indices pour l'ensemble d'entraînement et de test + train_indices = indices[:split_index] + test_indices = indices[split_index:] + # Sépare les données et étiquettes en ensembles d'entraînement et de test + data_train = data[train_indices] + labels_train = labels[train_indices] + data_test = data[test_indices] + labels_test = labels[test_indices] + return data_train, labels_train, data_test, labels_test + +