Skip to content
Snippets Groups Projects
Commit f98665d7 authored by Saidi Aya's avatar Saidi Aya
Browse files

Update read_cifar.py

parent 7c0ab631
No related branches found
No related tags found
No related merge requests found
......@@ -2,26 +2,14 @@
import numpy as np
from six.moves import cPickle as pickle
import os
import platform
#Defining the classes contained in the CIFAR-10 dataset
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
img_rows, img_cols = 32, 32 #The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
#This function takes a file name as an input and loads it in order to work on it later.
version = platform.python_version_tuple()
#The loading of the file depends on the version of python we are using.
if version[0] == '2':
return pickle.load(f)
elif version[0] == '3':
return pickle.load(f, encoding='latin1')
raise ValueError("invalid python version: {}".format(version))
import random
def unpickle(file):
'''loads the data dictionnary.'''
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def read_cifar_batch (batch_path):
#This function takes as parameter the path of a single batch as a string, and returns a matrix data of size (batch_size x data_size) and a a vector labels of size batch_size.
with open(batch_path, 'rb') as bp:
data_dict = load_pickle(bp)
data = data_dict['data']
labels = data_dict['labels']
......@@ -32,16 +20,22 @@ def read_cifar_batch (batch_path):
def read_cifar(directory_path):
#This function takes as parameter the path of the directory containing the six batches and returns a matrix data a vector lables of size batch_size
data=[]
labels=[]
for b in range(1,6):
file = os.path.join(directory_path, 'data_batch_%d'% (b, ))
Xd, Yd = read_cifar_batch(file)
data.append( Xd )
labels.append( Yd )
Xt, Yt = read_cifar_batch(os.path.join(directory_path, 'test_batch'))
data.append( Xt )
labels.append( Yt )
files=['/data_batch_1','/data_batch_2','/data_batch_3','/data_batch_4','/data_batch_5','/test_batch']
A=10000
N=60000
P=3072
data=np.empty((N,P),dtype=np.float)
labels=np.empty(A,dtype=np.int64)
for i in range(len(files)):
fichier=directory_path+files[i]
data_dict=unpickle(fichier)
M=data_dict[b'data']
L=data_dict[b'labels']
L=np.array(L)
data=np.vstack((X,M))
labels=np.hstack((Y,L))
data=X[N:2*N,]
labels=Y[A:,]
return data,labels
def split_dataset(data,labels,split):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment