import numpy as np
import pickle
import os


def unpickle(file):
    """Load and return the single object pickled in the file at ``file``.

    ``encoding='bytes'`` tells :func:`pickle.load` to read 8-bit strings
    written by Python 2 as ``bytes``. This is why callers index the result
    with keys such as ``b'data'``.

    NOTE: unpickling can execute arbitrary code; only use this on trusted files.
    """
    with open(file, mode='rb') as handle:
        loaded = pickle.load(handle, encoding='bytes')
    return loaded


def read_cifar_batch(file):
    """Read one pickled CIFAR batch file and return ``(data, labels)``.

    ``data`` is built from the ``b'data'`` entry as a float32 array.
    ``labels`` is built from the ``b'labels'`` entry as an int64 array.
    """
    batch = unpickle(file)
    return (
        np.array(batch[b'data'], dtype=np.float32),
        np.array(batch[b'labels'], dtype=np.int64),
    )


def read_cifar(batch_dir):
    """Load all CIFAR batches in ``batch_dir`` and stack them into one dataset.

    Files are read in this order: ``data_batch_1`` ... ``data_batch_5``, then
    ``test_batch``. Their rows are concatenated along axis 0 in that order.
    The original test batch is therefore NOT held out, and callers must
    re-split the result themselves.

    Returns:
        (data, labels): the concatenated arrays from every batch.
    """
    file_names = [f'data_batch_{k}' for k in range(1, 6)] + ['test_batch']
    batches = [read_cifar_batch(os.path.join(batch_dir, name)) for name in file_names]

    data = np.concatenate([batch_data for batch_data, _ in batches], axis=0)
    labels = np.concatenate([batch_labels for _, batch_labels in batches], axis=0)
    return data, labels


def split_dataset (data, labels, split) :
    """Randomly shuffle ``data``/``labels`` together and split them into train/test.

    Args:
        data: samples indexed along the first axis. Array-like input is
            accepted and converted with ``np.asarray``.
        labels: one label per sample, in the same order as ``data``.
        split: fraction of samples that go to the training set, in [0, 1].
            The train size is ``int(len(data) * split)``, and every
            remaining sample goes to the test set.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)`` as numpy
        arrays. Each sample stays paired with its label.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length, or if
            ``split`` is outside [0, 1].
    """
    if len(data) != len(labels) :
        raise ValueError("data and labels should have the same size in the first dimension")

    if split < 0 or split > 1 :
        raise ValueError("Split ratio should be between 0 and 1")

    # Converting to arrays enables fancy indexing. It also makes empty splits
    # keep the trailing feature dimensions, e.g. (0, 3072) instead of (0,).
    data = np.asarray(data)
    labels = np.asarray(labels)

    data_size = len(data)
    shuffled_indexes = np.random.permutation(data_size)
    train_set_size = int(data_size * split)

    # Slicing at exactly train_set_size (not +1) keeps the train size equal to
    # the requested fraction. It is also safe for split == 0 and split == 1.
    train_indexes = shuffled_indexes[:train_set_size]
    test_indexes = shuffled_indexes[train_set_size:]

    data_train = data[train_indexes]
    labels_train = labels[train_indexes]
    data_test = data[test_indexes]
    labels_test = labels[test_indexes]

    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    # Load every CIFAR batch (training batches + test batch), then re-split
    # with 90% of the samples going to the training set.
    data, labels = read_cifar('data/cifar-10-batches-py/')
    split = 0.9
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split)

    n_train = len(data_train)
    n_test = len(data_test)
    print(n_train, n_test, n_train + n_test, len(labels_train) + len(labels_test))