# read_cifar.py — CIFAR-10 dataset loading and train/test splitting utilities.
import numpy as np
import pickle
def read_cifar_batch(file):
    """Load a single pickled CIFAR-10 batch file.

    Parameters
    ----------
    file : str
        Path to one batch of the CIFAR-10 python distribution
        (e.g. ``data_batch_1`` or ``test_batch``).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` — ``data`` as float32 (one flattened image per
        row, presumably (n, 3072) for CIFAR-10; verify against the dataset),
        ``labels`` as int64 of shape (n,).
    """
    with open(file, 'rb') as fo:
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # Only use this on trusted files from the official CIFAR-10 release.
        batch = pickle.load(fo, encoding='bytes')  # renamed: don't shadow builtin `dict`
    data = np.asarray(batch[b'data'], dtype='float32')
    labels = np.asarray(batch[b'labels'], dtype='int64')
    return data, labels
def read_cifar(directory):
    """Load the entire CIFAR-10 dataset from *directory*.

    Reads the five training batches (``data_batch_1`` .. ``data_batch_5``)
    followed by ``test_batch`` and concatenates them along axis 0.

    Parameters
    ----------
    directory : str
        Path to the ``cifar-10-batches-py`` folder.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(data, labels)`` with all 6 batches stacked, training batches first.
    """
    batch_names = [f'data_batch_{i}' for i in range(1, 6)] + ['test_batch']
    batches = [read_cifar_batch(f'{directory}/{name}') for name in batch_names]
    data = np.concatenate([d for d, _ in batches], axis=0)
    labels = np.concatenate([l for _, l in batches], axis=0)
    return (data, labels)
def split_dataset(data, labels, split, seed=None):
    """Randomly split paired (data, labels) arrays into train and test sets.

    Parameters
    ----------
    data : np.ndarray
        Samples indexed along axis 0.
    labels : np.ndarray
        Labels, same length as ``data`` along axis 0.
    split : float
        Fraction in [0, 1] of samples assigned to the training set.
    seed : int | None, optional
        If given, seeds a local generator for a reproducible split;
        otherwise the legacy global NumPy RNG is used (original behavior).

    Returns
    -------
    tuple
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside [0, 1] or the array lengths differ.
    """
    if not 0.0 <= split <= 1.0:
        raise ValueError(f"split must be in [0, 1], got {split}")
    if data.shape[0] != labels.shape[0]:
        raise ValueError("data and labels must have the same length")
    # Local generator when seeded keeps the global RNG state untouched.
    rng = np.random.default_rng(seed) if seed is not None else np.random
    n = data.shape[0]
    train_size = int(n * split)
    # permutation() replaces the manual arange + in-place shuffle.
    indices = rng.permutation(n)
    train_idx = indices[:train_size]
    test_idx = indices[train_size:]
    return (data[train_idx], labels[train_idx],
            data[test_idx], labels[test_idx])
if __name__ == "__main__":
    # Smoke test: load the full dataset from disk and print a 60/40 split.
    data, labels = read_cifar("data/cifar-10-batches-py")
    split = split_dataset(data, labels, 0.6)
    print(split)