import pickle
import numpy as np
import os
from sklearn.model_selection import train_test_split

def read_cifar_batch(batch):
    """Read one pickled CIFAR batch file.

    Args:
        batch: Path of the pickled batch file to open.

    Returns:
        A ``(batch_data, batch_labels)`` tuple holding the values stored
        under the ``b'data'`` and ``b'labels'`` keys of the unpickled
        dictionary, returned exactly as stored in the file.

    Raises:
        KeyError: If the unpickled object lacks ``b'data'`` or ``b'labels'``.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling;
    # only call this on batch files that come from a trusted source.
    with open(batch, 'rb') as file:
        # encoding='bytes' keeps the dictionary keys as bytes objects,
        # which is why the lookups below use b'...' keys.
        batch_dict = pickle.load(file, encoding='bytes')

    batch_data = batch_dict[b'data']
    batch_labels = batch_dict[b'labels']

    return batch_data, batch_labels

def read_cifar(path):
    """Load every CIFAR batch file found in a directory.

    Every entry of ``path`` except ``batches.meta`` and ``readme.html`` is
    read with ``read_cifar_batch`` and the results are stacked together.

    Args:
        path: Directory containing the pickled batch files.

    Returns:
        A ``(data, labels)`` tuple where ``data`` is a ``float32`` array of
        shape ``(n_samples, 3072)`` and ``labels`` is a one-dimensional
        ``int64`` array of length ``n_samples``.

    Raises:
        ValueError: If the directory contains no batch files.
    """
    # Entries that sit next to the batches but are not pickled batches.
    skipped = ('batches.meta', 'readme.html')

    # os.listdir returns entries in arbitrary order; sorting makes the
    # resulting sample order reproducible across machines and runs.
    batch_names = sorted(
        name for name in os.listdir(path) if name not in skipped
    )
    if not batch_names:
        raise ValueError('no CIFAR batch files found in {!r}'.format(path))

    data, labels = [], []
    for name in batch_names:
        data_batch, labels_batch = read_cifar_batch(os.path.join(path, name))
        # Normalise each batch to (n, 3072) rows so that batches of different
        # sizes can still be stacked, instead of assuming 60000 samples total.
        data.append(np.asarray(data_batch, dtype=np.float32).reshape(-1, 3072))
        labels.append(np.asarray(labels_batch, dtype=np.int64).reshape(-1))

    data = np.concatenate(data, axis=0)
    labels = np.concatenate(labels, axis=0)

    return data, labels

def split_dataset(data, labels, split_factor):
    """Shuffle the samples and split them into a training and a test part.

    Args:
        data: Samples to split.
        labels: Labels matching ``data``.
        split_factor: Share of the samples kept for training; the remaining
            ``1 - split_factor`` is passed to ``train_test_split`` as
            ``test_size``.

    Returns:
        A ``(data_train, data_test, labels_train, labels_test)`` tuple.
    """
    held_out = 1-split_factor
    x_train, x_test, y_train, y_test = train_test_split(
        data,
        labels,
        test_size=held_out,
        shuffle=True,
    )
    return x_train, x_test, y_train, y_test