Skip to content
Snippets Groups Projects
Commit 54bf1931 authored by selalimi's avatar selalimi
Browse files

Add read_cifar.py

parent c8a8999c
Branches
No related tags found
No related merge requests found
import numpy as np
import os
import pickle
# Commentaire global expliquant le but du code
'''Here is the code to prepare the CIFAR dataset, create a function to read CIFAR batches, and split the dataset into training and testing sets:'''
# Fonction read_cifar_batch :
'''
Arguments :
Le chemin d'un seul batch en tant que chaîne de caractères.
Returns :
Une matrice de données de taille (taille_du_batch , taille_des_données)
Un vecteur d'étiquettes (labels) de taille (taille_du_batch)
'''
def read_cifar_batch(batch_path):
# Ouvre le fichier du batch et charge les données
with open(batch_path, 'rb') as f:
batch_dict = pickle.load(f, encoding='bytes')
# Convertit les données en float32
data = batch_dict[b'data'].astype(np.float32)
# Convertit les étiquettes en int64
labels = np.array(batch_dict[b'labels'], dtype=np.int64)
return data, labels
# Fonction read_cifar :
'''
*** lit le chemin du répertoire contenant tous les lots (y compris test_batch)***
*Arguments :
Le chemin du répertoire contenant les six lots (cinq data_batch et un test_batch) en tant que chaîne de caractères.
*Returns :
-Une matrice de données de taille (taille_du_lot , taille_des_données).
-Un vecteur d'étiquettes(labels) de taille (taille_du_lot).
'''
def read_cifar(folder):
# Liste des noms de fichiers de batch
batch_files = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5", "test_batch"]
data_list, labels_list = [], []
# Boucle sur les fichiers de batch
for batch_file in batch_files:
path = os.path.join(folder, batch_file)
# Appelle read_cifar_batch pour lire chaque batch
data, labels = read_cifar_batch(path)
data_list.append(data)
labels_list.append(labels)
# Combine les données de tous les batches
data = np.vstack(data_list)
# Combine les étiquettes de tous les batches
labels = np.hstack(labels_list)
return data, labels
# Fonction pour diviser les données en ensembles d'entraînement et de test :
'''
*Arguments :
-data et labels, deux tableaux de même taille dans la première dimension.
-split, un nombre flottant compris entre 0 et 1, qui détermine le facteur de répartition de l'ensemble d'entraînement par rapport à l'ensemble de test.
*Renvoie :
-data_train : les données d'entraînement.
-labels_train : les étiquettes correspondantes.
-data_test : les données de test.
-labels_test : les étiquettes correspondantes.
'''
def split_dataset(data, labels, split):
# Vérifie que le ratio de division est valide
assert 0 < split < 1
n = len(data)
# Mélange les indices des données
indices = np.random.permutation(n)
split_index = int(split * n)
# Sépare les indices pour l'ensemble d'entraînement et de test
train_indices = indices[:split_index]
test_indices = indices[split_index:]
# Sépare les données et étiquettes en ensembles d'entraînement et de test
data_train = data[train_indices]
labels_train = labels[train_indices]
data_test = data[test_indices]
labels_test = labels[test_indices]
return data_train, labels_train, data_test, labels_test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment