import pickle
import numpy as np


def unpickle(file):
    """Load and return the single object pickled in the file at *file*.

    The file is opened in binary mode and closed again before returning.
    ``encoding='bytes'`` makes Python 3 decode 8-bit strings written by a
    Python 2 pickle as ``bytes`` rather than ``str``; that is why callers
    index the result with ``bytes`` keys such as ``b'data'``.

    NOTE: unpickling can execute arbitrary code; only call this on files
    from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def read_cifar_batch(path):
    """Read one pickled CIFAR batch file and return ``(data, labels)``.

    The unpickled dict's ``b'data'`` entry is converted with
    ``np.float32`` and its ``b'labels'`` entry with ``np.int64``.
    (Presumably data is one flattened image per row, as in the CIFAR
    python format -- not checked here.)
    """
    raw_batch = unpickle(path)
    return np.float32(raw_batch[b'data']), np.int64(raw_batch[b'labels'])

def read_cifar(folder_path):
    """Load every CIFAR-10 batch in *folder_path* and concatenate them.

    Reads ``test_batch`` followed by ``data_batch_1`` .. ``data_batch_5``
    (all located inside *folder_path*) via ``read_cifar_batch``. The
    samples are stacked along the first axis in that order.

    Args:
        folder_path: directory containing the CIFAR-10 python batch files.

    Returns:
        ``(data, labels)``: the concatenated float32 data array and the
        int64 label array.
    """
    # The test batch comes first to keep the original sample ordering;
    # CIFAR-10 ships five training batches, numbered 1..5 inclusive.
    batch_names = ["test_batch"] + ["data_batch_" + str(i) for i in range(1, 6)]
    # Unpickle each file exactly once; the batches are large.
    batches = [read_cifar_batch(folder_path + "/" + name) for name in batch_names]
    data = np.concatenate([batch_data for batch_data, _ in batches])
    labels = np.concatenate([batch_labels for _, batch_labels in batches])
    return data, labels

def split_dataset(data, labels, split):
    """Split *data* and *labels* into a training part and a test part.

    The first ``int(split * len(data))`` samples go to the training set
    and the remainder go to the test set. Order is preserved; no shuffling
    is done.

    Args:
        data: array of samples, indexed along the first axis.
        labels: array of labels aligned with *data*.
        split: fraction in [0, 1] of samples to put in the training set.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises:
        ValueError: if *split* is outside [0, 1] or if *data* and *labels*
            have different lengths.
    """
    if not 0 <= split <= 1:
        raise ValueError("split must be in [0, 1], got %r" % (split,))
    if len(data) != len(labels):
        raise ValueError(
            "data and labels must have the same length (%d != %d)"
            % (len(data), len(labels))
        )
    index = int(split * len(data))
    # Slice at the index. np.split(arr, index) with an int would instead
    # divide the array into `index` equal sections.
    data_train, data_test = data[:index], data[index:]
    labels_train, labels_test = labels[:index], labels[index:]
    return data_train, labels_train, data_test, labels_test
    


if __name__ == "__main__":
    # Quick manual check: load the dataset from the default location and
    # dump the resulting arrays to stdout.
    dataset_dir = "./data/cifar-10-batches-py"
    images, targets = read_cifar(dataset_dir)
    for array in (images, targets):
        print(array)