diff --git a/read_cifar.py b/read_cifar.py index c2a03180445946798ed09c380dc89a8170abaf81..c89dd245076d157081aa83b1b7c12e56459160c7 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -11,16 +11,9 @@ import pickle -# batch.meta -#{b'num_cases_per_batch': 10000, b'label_names': [b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck'], b'num_vis': 3072} - def read_cifar_batch(file): with open(file, 'rb') as fo: dict = pickle.load(fo, encoding='bytes') - # keys = [b'batch_label', - # b'labels', - # b'data', - # b'filenames'] return (np.array(dict[b'data']).astype('float32'), np.array(dict[b'labels']).astype('int64')) def read_cifar(path): @@ -47,10 +40,6 @@ def read_cifar(path): def split_dataset(data, labels, split): - """ - Cette fonction divise l'ensemble de notre data en training data set et testing data set. - - """ n = data.shape[0] indices = np.random.permutation(n) train_idx, test_idx = indices[:int(split*n)], indices[int(split*n):] @@ -62,6 +51,7 @@ def split_dataset(data, labels, split): + if __name__ == "__main__": path = r'data\cifar-10-batches-py\data_batch_1' main_path = r'data\cifar-10-batches-py'