## Imports
import os
import pickle

import numpy as np


## QUESTION 2
def unpickle(file):
    """Load one pickled CIFAR batch file and return its dictionary.

    Source: https://www.cs.toronto.edu/~kriz/cifar.html

    Args:
        file: Path to the pickled batch file.
    Returns:
        The unpickled dictionary. Keys are bytes (e.g. b"data", b"labels")
        because the official files were pickled under Python 2.
    """
    # NOTE(security): pickle.load must only ever see trusted files such as the
    # official CIFAR download — unpickling untrusted data can execute code.
    with open(file, "rb") as fo:
        # Renamed from `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding="bytes")
    return batch


def read_cifar_batch(file):
    """Read a single batch of the CIFAR dataset.

    Args:
        file: The path to the batch file (docstring previously said "path",
            which did not match the actual parameter name).
    Returns:
        data: A np.float32 array of shape batch_size x data_size, where
            batch_size is the number of images in the batch and data_size is
            the number of numerical values describing one image.
        labels: A np.int64 array of size batch_size whose values are the
            class codes of the rows of `data` at the same index.
    """
    batch = unpickle(file)  # renamed from `dict` to avoid shadowing the builtin
    data = np.array(batch[b"data"], dtype=np.float32)
    labels = np.array(batch[b"labels"], dtype=np.int64)
    return data, labels


## QUESTION 3
def read_cifar(path):
    """Read the whole CIFAR-10 dataset (5 training batches + test batch).

    Args:
        path: The directory containing the CIFAR batch files.
    Returns:
        data: A np.float32 array of shape batch_size x data_size, where
            batch_size is the total number of images across all batches
            (including test_batch) and data_size is the per-image dimension.
        labels: A np.int64 array of size batch_size whose values are the
            class codes of the rows of `data` at the same index.
    """
    data_batches = []
    label_batches = []
    # Five training batches followed by the test batch, in a fixed order.
    batch_names = ["data_batch_" + str(i) for i in range(1, 6)] + ["test_batch"]
    for name in batch_names:
        # os.path.join keeps this portable; the original hard-coded the
        # Windows separator (r"\..."), which breaks on POSIX systems.
        batch_data, batch_labels = read_cifar_batch(os.path.join(path, name))
        data_batches.append(batch_data)
        label_batches.append(batch_labels)
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


## QUESTION 4
def split_dataset(data, labels, split):
    """Shuffle the dataset, then split it into training and validation sets.

    Args:
        data: A np.float32 array of shape batch_size x data_size.
        labels: A np.int64 array of size batch_size, aligned row-for-row
            with `data`.
        split: A float between 0 and 1: the fraction of samples used for
            training. For example, split = 0.8 uses 80% of the data for
            training and 20% for validation.
    Returns:
        data_train: np.float32 array, the first split fraction of the
            shuffled data.
        labels_train: np.int64 array, the matching training labels.
        data_test: np.float32 array, the remaining validation data.
        labels_test: np.int64 array, the matching validation labels.
    Raises:
        ValueError: If split is outside [0, 1].
    """
    # Explicit exception instead of `assert`: asserts are stripped under -O.
    if not 0 <= split <= 1:
        raise ValueError("split must be between 0 and 1, got {}".format(split))
    n_samples = data.shape[0]
    # One random permutation applied to both arrays keeps each label paired
    # with its image row.
    indices = np.random.permutation(n_samples)
    data = data[indices]
    labels = labels[indices]
    # Split the shuffled arrays at the requested fraction.
    split_index = int(n_samples * split)
    data_train = data[:split_index]
    labels_train = labels[:split_index]
    data_test = data[split_index:]
    labels_test = labels[split_index:]
    # Debug prints removed: a library function should not write to stdout.
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    dict = unpickle(r"data\cifar-10-batches-py\data_batch_1")
    print(dict.keys())
    print(dict[b"data"].shape)
    print(dict[b"labels"][:10])

    data, labels = read_cifar_batch(r"data\cifar-10-batches-py\data_batch_1")
    print(data.dtype)
    print(labels.dtype)
    print(data.shape)
    print(labels.shape)

    data, labels = read_cifar(r"data\cifar-10-batches-py")
    print(data.dtype)
    print(labels.dtype)
    print(data.shape)
    print(labels.shape)

    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)
    print(data_train.dtype)
    print(labels_train.dtype)
    print(data_train.shape)
    print(labels_train.shape)
    print(data_test.dtype)
    print(labels_test.dtype)
    print(data_test.shape)
    print(labels_test.shape)