Skip to content
Snippets Groups Projects
Commit 319c8de1 authored by Cart Milan's avatar Cart Milan
Browse files

Part 1 : Prepare the CIFAR dataset

parent 49087cb0
Branches
No related tags found
No related merge requests found
File added
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
import pandas as pd
import pickle
......@@ -12,4 +16,37 @@ def read_cifar_batch(batch_path):
return data, labels
print(read_cifar_batch('Data/cifar-10-batches-py/data_batch_2'))
\ No newline at end of file
def read_cifar(path):
data = []
labels = []
#Add the 5 batches
for i in range(1,6):
data_temp, labels_temp = read_cifar_batch(f'{path}/data_batch_{i}')
data.append(data_temp)
labels.append(labels_temp)
#Add the test batches
data_temp, labels_temp = read_cifar_batch(f'{path}/test_batch')
data.append(data_temp)
labels.append(labels_temp)
#Concatenate all the batches to create a big one
data = np.concatenate(data, axis = 0)
labels = np.concatenate(labels, axis = 0)
return(data, labels)
def split_dataset(data, labels, split):
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=(1 - split), random_state=0)
return(X_train, X_test, y_train, y_test)
if __name__== '__main__':
data, labels = read_cifar_batch('Data/cifar-10-batches-py/data_batch_1')
data, labels = read_cifar('/Users/milancart/Documents/GitHub/image-classification/Data/cifar-10-batches-py')
X_train, X_test, y_train, y_test = split_dataset(data, labels, 0.8)
print(X_train, X_test, y_train, y_test)
\ No newline at end of file
import read_cifar
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment