"""Load the CIFAR-10 python batches and split them into train/test sets."""

import pickle
import random  # noqa: F401  (no longer used in this module; import retained)
from pathlib import Path

import numpy as np

# Number of ``data_batch_<i>`` files read by ``read_cifar`` (numbered from 1).
_NUM_TRAIN_BATCHES = 5


def unpickle(file):
    """Return the object stored in the pickle file at *file*.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. Strings pickled by Python 2 are decoded as
        ``bytes`` (hence the ``b'data'`` / ``b'labels'`` keys used below).
    """
    # NOTE: pickle can execute arbitrary code while loading; only use this
    # on batch files obtained from a trusted source.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def read_cifar_batch(path):
    """Read one batch file and return its samples and labels.

    Args:
        path: Path of a single pickled batch file.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is the batch's ``b'data'`` entry
        converted to ``float32`` and ``labels`` its ``b'labels'`` entry
        converted to ``int64``.
    """
    batch = unpickle(path)
    data = np.asarray(batch[b'data'], dtype=np.float32)
    labels = np.asarray(batch[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(folder_path):
    """Read every batch found in *folder_path* into a single dataset.

    The test batch comes first, followed by ``data_batch_1`` through
    ``data_batch_5``, all stacked along the first axis.

    Args:
        folder_path: Directory containing ``test_batch`` and the
            ``data_batch_<i>`` files.

    Returns:
        A ``(data, labels)`` tuple of concatenated ``float32`` samples and
        ``int64`` labels.
    """
    folder = Path(folder_path)
    batch_names = ["test_batch"] + [
        f"data_batch_{i}" for i in range(1, _NUM_TRAIN_BATCHES + 1)
    ]

    # Each file is unpickled exactly once; the pieces are joined at the end
    # because NumPy arrays cannot grow in place (np.append returns a copy).
    data_parts = []
    label_parts = []
    for name in batch_names:
        batch_data, batch_labels = read_cifar_batch(folder / name)
        data_parts.append(batch_data)
        label_parts.append(batch_labels)

    return np.concatenate(data_parts), np.concatenate(label_parts)


def split_dataset(data, labels, split):
    """Shuffle the dataset and split it into training and testing subsets.

    Args:
        data: Array of samples, one per row.
        labels: Array of labels aligned with *data*.
        split: Fraction of the samples assigned to the training set. It is
            passed to scikit-learn as a float ``train_size``, which must be
            strictly between 0 and 1.

    Returns:
        A ``(data_train, labels_train, data_test, labels_test)`` tuple.
    """
    # Imported here so the loaders above stay usable without scikit-learn.
    from sklearn.model_selection import train_test_split

    # train_size=split avoids the float rounding of ``1 - split``, which can
    # move the boundary by one sample. The fixed seed keeps the shuffle
    # reproducible from one run to the next.
    data_train, data_test, labels_train, labels_test = train_test_split(
        data, labels, train_size=split, random_state=1
    )
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":

    # Extraction of the data from Cifar database
    data, labels = read_cifar("./data/cifar-10-batches-py")
    print(data)
    print(labels)

    # Formatting the data into training and testing sets
    split = 0.21
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)
    print(data_train)
    print(labels_train)
    print(data_test)
    print(labels_test)