Skip to content
Snippets Groups Projects
Commit 69db049d authored by Saidi Aya's avatar Saidi Aya
Browse files

Update read_cifar.py

parent f9ba8a36
Branches
No related tags found
No related merge requests found
#Importing the useful libraries we will be using later
import numpy as np
from six.moves import cPickle as pickle
import os
import platform
from sklearn.utils import shuffle
# The ten CIFAR-10 class names; the index of a name in this tuple is the
# integer label used in the batch files.
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
img_rows, img_cols = 32, 32 #The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes
# Expected model input shape: 32x32 pixels with 3 colour channels (RGB).
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
    """Deserialize one pickled object from the open binary file *f*.

    The CIFAR-10 batches were pickled under Python 2, so on Python 3 the
    payload must be decoded with latin1; Python 2 can read it directly.

    Raises:
        ValueError: if the interpreter is neither Python 2 nor Python 3.
    """
    version = platform.python_version_tuple()
    major = version[0]
    if major == '2':
        return pickle.load(f)
    elif major == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))
def read_cifar_batch(batch_path):
    """Load a single CIFAR-10 batch file.

    Parameters
    ----------
    batch_path : str
        Path to one pickled batch file (e.g. ``data_batch_1``).

    Returns
    -------
    data : np.ndarray
        float32 array of shape (batch_size, 3072): one flattened
        32x32x3 image per row.
    labels : np.ndarray
        int64 vector of shape (batch_size,).
    """
    with open(batch_path, 'rb') as bp:
        data_dict = load_pickle(bp)
    data = data_dict['data']
    labels = data_dict['labels']
    # Infer the batch size instead of hard-coding 10000, so batches of any
    # size (e.g. a small test fixture) load correctly.
    data = data.reshape(-1, 3072).astype(np.float32)
    # The original used dtype 'i' (int32) although its comment promised
    # int64; use int64 to honour the stated contract.
    labels = np.array(labels, dtype=np.int64)
    return data, labels
def read_cifar(directory_path):
    """Load the full CIFAR-10 dataset (5 training batches + 1 test batch).

    Parameters
    ----------
    directory_path : str
        Directory containing ``data_batch_1`` .. ``data_batch_5`` and
        ``test_batch``.

    Returns
    -------
    data : np.ndarray
        float32 matrix of shape (n_samples, 3072) with all batches stacked.
    labels : np.ndarray
        int vector of shape (n_samples,) aligned with ``data`` rows.
    """
    data_parts = []
    label_parts = []
    for b in range(1, 6):
        batch_file = os.path.join(directory_path, 'data_batch_%d' % (b,))
        xb, yb = read_cifar_batch(batch_file)
        data_parts.append(xb)
        label_parts.append(yb)
    xt, yt = read_cifar_batch(os.path.join(directory_path, 'test_batch'))
    data_parts.append(xt)
    label_parts.append(yt)
    # Concatenate into one matrix / one vector, as documented; the original
    # returned Python lists of per-batch arrays instead of a single matrix.
    return np.concatenate(data_parts, axis=0), np.concatenate(label_parts, axis=0)
def split_dataset(data, labels, split):
    """Shuffle and split paired (data, labels) arrays into train/test sets.

    Parameters
    ----------
    data, labels : array-like
        Arrays with the same length along the first dimension.
    split : float
        Fraction in [0, 1] of samples assigned to the training set.

    Returns
    -------
    (data_train, data_test, labels_train, labels_test) : tuple of np.ndarray

    Raises
    ------
    ValueError: if ``split`` is outside [0, 1].

    Note: the original body called pandas-only methods (``.sample``,
    ``.drop``) on numpy arrays and returned an undefined ``labels_test``,
    so it always raised at runtime.
    """
    if not 0.0 <= split <= 1.0:
        raise ValueError("split must be in [0, 1], got {}".format(split))
    data = np.asarray(data)
    labels = np.asarray(labels)
    # A single shared permutation keeps each sample paired with its label
    # (seed 25 preserved from the original for reproducibility).
    perm = np.random.default_rng(25).permutation(data.shape[0])
    data, labels = data[perm], labels[perm]
    n_train = int(round(split * data.shape[0]))
    data_train, data_test = data[:n_train], data[n_train:]
    labels_train, labels_test = labels[:n_train], labels[n_train:]
    return data_train, data_test, labels_train, labels_test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment