Skip to content
Snippets Groups Projects
Commit df8a6a82 authored by Audard Lucile's avatar Audard Lucile
Browse files

Update read_cifar.py

parent 6c0a1bf6
Branches
No related tags found
No related merge requests found
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
import random
def unpickle(file):
......@@ -7,29 +9,54 @@ def unpickle(file):
batch = pickle.load(fo, encoding='bytes')
return batch
def read_cifar_batch(path):
batch = unpickle(path)
data = batch[b'data']
labels = batch[b'labels']
return np.float32(data), np.int64(labels)
def read_cifar(folder_path):
# Get the test batch
data, labels = read_cifar_batch("./data/cifar-10-batches-py/test_batch")
# Concatenate with the 5 data batches
for i in range(1,5):
data = np.concatenate((data, read_cifar_batch(folder_path + "/data_batch_" + str(i))[0]))
labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1]))
np.append(data, read_cifar_batch(folder_path + "/data_batch_" + str(i))[0])
np.append(labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1])
return data, labels
def split_dataset(data, labels, split):
# Determination of an index to split the data
index = int(split * len(data))
data_train, data_test = np.split(data, index)
labels_train, labels_test = np.split(labels, index)
# Split the data on the index
tableau_combine = list(zip(data, labels))
random.shuffle(tableau_combine)
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=1-split, random_state=1)
# data_train, data_test = np.split(data, [index])
# labels_train, labels_test = np.split(labels, [index])
return data_train, labels_train, data_test, labels_test
if __name__ == "__main__":
# Extraction of the data from Cifar database
data, labels = read_cifar("./data/cifar-10-batches-py")
print(data)
print(labels)
# Formatting the data into training and testing sets
split = 0.21
data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
print(data_train)
print(labels_train)
print(data_test)
print(labels_test)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment