Skip to content
Snippets Groups Projects
Commit 0386569d authored by BaptisteBrd's avatar BaptisteBrd
Browse files

tri cifar done

parent 0e07e0d3
No related tags found
No related merge requests found
import numpy as np
import pickle
def read_cifar_batch(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
......@@ -12,24 +13,49 @@ def read_cifar_batch(file):
#print(vect1)
def read_cifar(directory):
    """Load every batch file under *directory* and merge the results.

    Reads ``data_batch_1`` through ``data_batch_5`` and then ``test_batch``
    with ``read_cifar_batch``, and concatenates what each call returns
    along axis 0, preserving that file order.

    Args:
        directory: Path of the folder that contains the batch files.

    Returns:
        A ``(data, labels)`` tuple holding the concatenated data array and
        the concatenated labels array.
    """
    # One list of file names so the training batches and the test batch go
    # through exactly the same code path.
    batch_names = [f'data_batch_{i}' for i in range(1, 6)]
    batch_names.append('test_batch')

    data = []
    labels = []
    for name in batch_names:
        data_v, labels_v = read_cifar_batch(f'{directory}/{name}')
        data.append(data_v)
        labels.append(labels_v)

    data = np.concatenate(data, axis=0)
    labels = np.concatenate(labels, axis=0)
    return data, labels
def split_dataset(data, labels, split):
    """Randomly partition *data* and *labels* into two aligned subsets.

    The sample indices are shuffled with ``np.random.shuffle`` and the first
    ``int(n * split)`` of them form the training subset; the remaining ones
    form the test subset. The same indices are applied to *data* and
    *labels*, so each sample keeps its label.

    Args:
        data: Array whose first axis indexes the samples.
        labels: Array indexable with the same sample indices as *data*.
        split: Fraction of the samples assigned to the training subset,
            between 0 and 1 inclusive.

    Returns:
        A ``(data_train, labels_train, data_test, labels_test)`` tuple.

    Raises:
        ValueError: If *split* is outside the ``[0, 1]`` range.
    """
    if not 0 <= split <= 1:
        raise ValueError(f"split must be between 0 and 1, got {split!r}")

    data_size = data.shape[0]
    train_size = int(data_size * split)

    # Shuffle indices rather than the arrays themselves so that data and
    # labels are reordered consistently.
    indices = np.arange(data_size)
    np.random.shuffle(indices)
    indices_train = indices[:train_size]
    indices_test = indices[train_size:]

    data_train = data[indices_train]
    labels_train = labels[indices_train]
    data_test = data[indices_test]
    labels_test = labels[indices_test]
    return data_train, labels_train, data_test, labels_test
if __name__ == "__main__":
    # Manual smoke test: load all batches, then print a 60/40 random split.
    cifar_data, cifar_labels = read_cifar("data/cifar-10-batches-py")
    print(split_dataset(cifar_data, cifar_labels, 0.6))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment