Skip to content
Snippets Groups Projects
Commit 99930e54 authored by Denis Thomas's avatar Denis Thomas
Browse files

Upload New File

parent e99e5888
No related branches found
No related tags found
No related merge requests found
import numpy as np
import pickle
import os
from typing import Tuple, List
import matplotlib.pyplot as plt
# Question 2: function to read one CIFAR batch
def read_cifar_batch(batch_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load a single CIFAR-10 batch file.

    Parameters
    ----------
    batch_path : str
        Path to one pickled batch file (e.g. ``data_batch_1`` or ``test_batch``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` where ``data`` is float32 of shape (N, 3072) and
        ``labels`` is int64 of shape (N,).
    """
    # NOTE(security): pickle.load executes arbitrary code from the file —
    # only use on trusted CIFAR archives downloaded from the official source.
    with open(batch_path, 'rb') as f:
        raw = pickle.load(f, encoding='latin1')
    # Cast pixels to float32 for downstream numeric work; labels to int64.
    images = raw['data'].astype(np.float32)
    targets = np.array(raw['labels'], dtype=np.int64)
    return images, targets
# Question 3: function to read every CIFAR batch
def read_cifar(cifar_dir: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load and concatenate all six CIFAR-10 batches (5 train + 1 test).

    Parameters
    ----------
    cifar_dir : str
        Directory containing ``data_batch_1`` .. ``data_batch_5`` and
        ``test_batch`` (e.g. ``data/cifar-10-batches-py``).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(all_data, all_labels)`` with every batch stacked along axis 0,
        training batches first, test batch last.
    """
    # Build the file list in the same order as the original loop:
    # the five training batches (names run from 1 to 5), then the test batch.
    paths = [os.path.join(cifar_dir, f'data_batch_{i}') for i in range(1, 6)]
    paths.append(os.path.join(cifar_dir, 'test_batch'))
    # Reuse the single-batch reader for each file.
    batches = [read_cifar_batch(p) for p in paths]
    # Stack everything with numpy's concatenate.
    all_data = np.concatenate([d for d, _ in batches])
    all_labels = np.concatenate([lbl for _, lbl in batches])
    return all_data, all_labels
# Question 4: Fonction pour split nos datas
def split_dataset(data: np.ndarray, labels: np.ndarray, split: float,
                  seed: int = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Shuffle the dataset and split it into train/test partitions.

    Parameters
    ----------
    data : np.ndarray
        Sample array; the first axis indexes samples.
    labels : np.ndarray
        Label array aligned with ``data`` along the first axis.
    split : float
        Fraction (in [0, 1]) of samples assigned to the training set.
    seed : int, optional
        If given, makes the shuffle reproducible via a dedicated RNG.
        Defaults to ``None``, which preserves the original behavior of
        drawing from the global ``np.random`` state.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        ``(data_train, labels_train, data_test, labels_test)``.
    """
    # Use one permutation for both arrays so each sample stays paired
    # with its label after shuffling.
    rng = np.random if seed is None else np.random.default_rng(seed)
    permutation = rng.permutation(len(data))
    data = data[permutation]
    labels = labels[permutation]
    # Index at which to cut between train and test.
    split_idx = int(len(data) * split)
    data_train = data[:split_idx]
    labels_train = labels[:split_idx]
    data_test = data[split_idx:]
    labels_test = labels[split_idx:]
    return data_train, labels_train, data_test, labels_test
if __name__ == "__main__":
    # Directory holding the downloaded CIFAR-10 python batches.
    cifar_dir = "data/cifar-10-batches-py"
    # Load the full dataset (train + test batches concatenated).
    full_data, full_labels = read_cifar(cifar_dir)
    print(f"data shape: {full_data.shape}")
    print(f"labels shape: {full_labels.shape}")
    # Re-split it 90/10 into shuffled train/test partitions.
    train_x, train_y, test_x, test_y = split_dataset(full_data, full_labels, split=0.9)
    print(f"Training shape: {train_x.shape}")
    print(f"Testing shape: {test_x.shape}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment