import pickle
import numpy as np
import os
from sklearn.model_selection import train_test_split

def read_cifar_batch(batch):
    """Load one pickled CIFAR batch file.

    The batch's ``b'batch_label'`` entry is printed as a progress message.

    Args:
        batch: Path to the pickled batch file.

    Returns:
        A ``(data, labels)`` tuple holding the ``b'data'`` and ``b'labels'``
        entries of the unpickled dictionary, returned unchanged.

    Raises:
        KeyError: If the unpickled dictionary lacks one of the expected keys.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file; only
        call this on batch files obtained from a trusted source.
    """
    with open(batch, 'rb') as fo:
        # encoding='bytes' decodes Python 2 ``str`` objects as ``bytes``,
        # which is why the keys below are looked up as bytes literals.
        batch_dict = pickle.load(fo, encoding='bytes')
    data = batch_dict[b'data']
    labels = batch_dict[b'labels']
    print(batch_dict[b'batch_label'])
    return data, labels

def read_cifar(path):
    """Load every CIFAR batch file in *path* and stack them into flat arrays.

    Every entry of ``path`` except ``batches.meta``, ``readme.html`` and hidden
    dot-files is treated as a pickled batch and read with ``read_cifar_batch``.
    Files are visited in sorted name order so the row order is reproducible.

    Args:
        path: Directory holding the extracted CIFAR batch files.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is a float32 array of shape
        ``(n_samples, 3072)`` and ``labels`` is an int64 array of shape
        ``(n_samples,)``, where ``n_samples`` is the total over all batches.

    Raises:
        ValueError: If ``path`` contains no batch files.
    """
    # Non-batch files shipped alongside the pickled batches.
    skipped = {'batches.meta', 'readme.html'}
    data, labels = [], []
    # os.listdir order is arbitrary; sort so the stacked rows are deterministic.
    for batch in sorted(os.listdir(path)):
        # Hidden files (e.g. .DS_Store) are not pickles and would crash the load.
        if batch in skipped or batch.startswith('.'):
            continue
        data_batch, labels_batch = read_cifar_batch(os.path.join(path, batch))
        # Keep the raw dtype while collecting; cast once after concatenation
        # to avoid holding an extra full-size float32 copy.
        data.append(np.asarray(data_batch).reshape(-1, 3072))
        labels.append(np.asarray(labels_batch).reshape(-1))
    if not data:
        raise ValueError(f"no CIFAR batch files found in {path!r}")
    # Concatenate rather than reshape to a fixed 60000 rows, so a directory with
    # only a subset of batches (or differently sized batches) also loads.
    return (
        np.concatenate(data).astype(np.float32),
        np.concatenate(labels).astype(np.int64),
    )

def split_dataset(data, labels, split, *, random_state=None):
    """Randomly shuffle and split samples and labels into train/test subsets.

    Args:
        data: Samples, indexed along the first axis.
        labels: Labels aligned with ``data`` along the first axis.
        split: Fraction of samples assigned to the training set, strictly
            between 0 and 1.
        random_state: Optional seed (or RNG) forwarded to ``train_test_split``
            for a reproducible shuffle; ``None`` keeps the unseeded behaviour.

    Returns:
        ``(data_train, data_test, labels_train, labels_test)``.
    """
    # Pass the training fraction directly instead of ``test_size=1 - split``:
    # the float subtraction can round up (1 - 0.7 == 0.30000000000000004) and,
    # since scikit-learn rounds the test count up with ceil(), that silently
    # moves an extra sample from train to test (e.g. 6/4 instead of 7/3 on 10).
    return train_test_split(
        data,
        labels,
        train_size=split,
        shuffle=True,
        random_state=random_state,
    )