Skip to content
Snippets Groups Projects
Commit 1a0d1140 authored by Chauvin Hugo's avatar Chauvin Hugo
Browse files

Update read_cifar.py

parent 2107659c
No related branches found
No related tags found
No related merge requests found
import numpy as np
import os
import pickle
import random
def unpickle(file):
import pickle
with open(file, 'rb') as f:
dict = pickle.load(f, encoding='bytes')
return dict
def read_cifar_batch(batch_path) :
with open(batch_path, 'rb') as file:
# On unpickle le batch
batch = pickle.load(file, encoding='bytes')
# Extraction de data et labels
data = np.array(batch[b'data'], dtype=np.float32)/255.0
labels = np.array(batch[b'labels'], dtype = np.int64)
return data, labels
def read_cifar(batch_dir):
data_batches = []
label_batches = []
# Itération sur les batches
for file_name in os.listdir(batch_dir):
if file_name.startswith("data_batch") or file_name.startswith("test_batch") :
batch_path = os.path.join(batch_dir, file_name)
data, labels = read_cifar_batch(batch_path)
data_batches.append(data)
label_batches.append(labels)
# On combine data et labels depuis tous les batches
data = np.concatenate(data_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)
return data, labels
def split_dataset(data, labels, split):
# On vérifie la bonne dimension de data et labels
if data.shape[0] != labels.shape[0]:
return OSError("data et labels doivent avoir le même nombre de lignes !")
# On détermine la taille des data train et test
train_size = round(data.shape[0]*split)
# On shuffle les data et labels
shuffle_index = [i for i in range(data.shape[0])]
# On extirpe les data/labels train et test
data_train = data[shuffle_index][:train_size]
labels_train = np.array([[labels[i]] for i in shuffle_index])[:train_size]
data_test = data[shuffle_index][train_size:]
labels_test = np.array([[labels[i]] for i in shuffle_index])[train_size:]
return data_train, labels_train, data_test, labels_test
if __name__ == "__main__" :
data_folder = 'C:\\Users\\hugol\\Desktop\\Centrale Lyon\\Centrale Lyon 4A\\Informatique\\Machine Learning\\BE1\\cifar-10-batches-py'
batch_filename = 'data_batch_1'
batch_path = os.path.join(data_folder, batch_filename)
data, labels = read_cifar_batch(batch_path)
print("Data shape:", data.shape)
print("Labels shape:", labels.shape)
data_all, labels_all = read_cifar(data_folder)
print("Data shape:", data_all.shape)
print("Labels shape:", labels_all.shape)
data_train, labels_train, data_test, labels_test = split_dataset(data_all, labels_all, 0.9)
print("Data train shape:", data_train.shape)
print("Labels train shape:", labels_train.shape)
print("Data test shape:", data_test.shape)
print("Labels test shape:", labels_test.shape)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment