import numpy as np
import pickle
import os


def unpickle(file):
    # Load a pickled CIFAR-10 batch file into a dictionary (keys are bytes objects)
    with open(file, 'rb') as fo:
        batch_dict = pickle.load(fo, encoding='bytes')
    return batch_dict


def read_cifar_batch(file):
    # Return the images and labels of a single batch as NumPy arrays
    batch_dict = unpickle(file)
    data = np.array(batch_dict[b'data'], dtype=np.float32)
    labels = np.array(batch_dict[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(batch_dir):
    # Gather the five training batches and the test batch into a single dataset
    data_batches = []
    label_batches = []
    for i in range(1, 6):
        batch_name = f'data_batch_{i}'
        batch_path = os.path.join(batch_dir, batch_name)
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)
    # The test batch is appended as well; split_dataset produces a fresh split later
    test_batch_path = os.path.join(batch_dir, 'test_batch')
    data_test, labels_test = read_cifar_batch(test_batch_path)
    data_batches.append(data_test)
    label_batches.append(labels_test)
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


def split_dataset(data, labels, split):
    # Randomly split (data, labels) into a training set containing a fraction
    # `split` of the samples and a test set containing the remaining samples
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same size in the first dimension")
    if split < 0 or split > 1:
        raise ValueError("split ratio must be between 0 and 1")
    data_size = len(data)
    shuffled_indexes = np.random.permutation(data_size)
    train_set_size = int(data_size * split)
    train_indexes = shuffled_indexes[:train_set_size]
    test_indexes = shuffled_indexes[train_set_size:]
    data_train = data[train_indexes]
    labels_train = labels[train_indexes]
    data_test = data[test_indexes]
    labels_test = labels[test_indexes]
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    data, labels = read_cifar('data/cifar-10-batches-py/')
    split = 0.9
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    print(len(data_train), len(data_test), len(data_train) + len(data_test), len(labels_train) + len(labels_test))
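
    # A minimal sanity check of split_dataset on small synthetic arrays
    # (an illustrative sketch; the fake_* arrays are made-up values, not
    # CIFAR-10 data): with split = 0.9, 9 of the 10 rows should land in
    # the training set and 1 row in the test set.
    fake_data = np.arange(20, dtype=np.float32).reshape(10, 2)
    fake_labels = np.arange(10, dtype=np.int64)
    d_tr, l_tr, d_te, l_te = split_dataset(fake_data, fake_labels, 0.9)
    print(d_tr.shape, l_tr.shape, d_te.shape, l_te.shape)  # expected: (9, 2) (9,) (1, 2) (1,)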