Select Git revision
arbre-parcours-largeur.py
Forked from
Vuillemot Romain / INF-TC1
Source project has a limited visibility.
read_cifar.py 2.09 KiB
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 7 08:45:09 2024
@author: danjo
"""
import numpy as np
import pickle
# batch.meta
#{b'num_cases_per_batch': 10000, b'label_names': [b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck'], b'num_vis': 3072}
def read_cifar_batch(file):
    """Load a single pickled CIFAR-10 batch file.

    Parameters
    ----------
    file : str or os.PathLike
        Path to a batch file such as ``data_batch_1`` or ``test_batch``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, labels)``: the batch's ``b'data'`` entry as float32 and its
        ``b'labels'`` entry as int64.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # load batch files that come from a trusted source.
    with open(file, 'rb') as fo:
        # encoding='bytes' is required for the Python-2 pickles; the dict keys
        # are therefore bytes (b'data', b'labels', ...).
        batch = pickle.load(fo, encoding='bytes')
    # Other keys present in a batch (b'batch_label', b'filenames') are unused.
    data = np.asarray(batch[b'data'], dtype=np.float32)
    labels = np.asarray(batch[b'labels'], dtype=np.int64)
    return (data, labels)
def read_cifar(path):
    """Load every CIFAR-10 batch found under ``path`` into one dataset.

    The five training batches (``data_batch_1`` .. ``data_batch_5``) are read
    first, then ``test_batch``. Their rows are stacked along axis 0 in that
    order.

    Parameters
    ----------
    path : str
        Directory containing the extracted ``cifar-10-batches-py`` files.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, labels)``: the concatenation of what ``read_cifar_batch``
        returns for each file.
    """
    batch_names = [f'data_batch_{k}' for k in range(1, 6)] + ['test_batch']
    batches = [read_cifar_batch(f'{path}/{name}') for name in batch_names]
    all_data = np.concatenate([chunk for chunk, _ in batches], axis=0)
    all_labels = np.concatenate([lab for _, lab in batches], axis=0)
    return (all_data, all_labels)
def split_dataset(data, labels, split, seed=None):
    """Randomly split samples and labels into a training set and a testing set.

    Parameters
    ----------
    data : array-like
        Samples indexed along axis 0; any number of trailing dimensions.
    labels : array-like
        One label per sample, aligned with ``data`` along axis 0.
    split : float
        Fraction of samples placed in the training set, in ``[0, 1]``. The
        training set receives ``int(split * n)`` samples (rounded down) and
        the testing set receives the rest.
    seed : int, numpy.random.Generator or None, optional
        If None (default), the global ``np.random`` state drives the shuffle.
        Otherwise it is passed to ``np.random.default_rng`` for a
        reproducible split.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data_train, data_test, labels_train, labels_test)``. The data
        arrays are float32 and the label arrays are int64.

    Raises
    ------
    ValueError
        If ``split`` is outside ``[0, 1]`` or if ``data`` and ``labels`` do
        not have the same number of samples.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    # Also rejects NaN, since every comparison with NaN is False.
    if not 0 <= split <= 1:
        raise ValueError(f"split must be in [0, 1], got {split!r}")
    n = data.shape[0]
    if labels.shape[0] != n:
        raise ValueError(
            f"data has {n} samples but labels has {labels.shape[0]}"
        )
    if seed is None:
        indices = np.random.permutation(n)
    else:
        indices = np.random.default_rng(seed).permutation(n)
    n_train = int(split * n)
    train_idx, test_idx = indices[:n_train], indices[n_train:]
    # Index along axis 0 only, so both 1-D and N-D sample arrays work.
    data_train = data[train_idx].astype(np.float32)
    data_test = data[test_idx].astype(np.float32)
    labels_train = labels[train_idx].astype(np.int64)
    labels_test = labels[test_idx].astype(np.int64)
    return data_train, data_test, labels_train, labels_test
if __name__ == "__main__":
    # Forward slashes are accepted by open() on both Windows and POSIX,
    # whereas a backslash path is a single odd filename on Linux/macOS.
    main_path = 'data/cifar-10-batches-py'
    data, labels = read_cifar(main_path)
    data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)
    print(data_train.shape, labels_train.shape, data_test.shape, labels_test.shape)