"""Helpers to load pickled CIFAR batch files and split them into train/test sets."""
import numpy as np
import pickle
import os

def read_cifar_batch(batch_path):
    """Load a single pickled CIFAR batch file.

    The file is unpickled with ``encoding='bytes'``, so the dictionary it
    contains is accessed with ``bytes`` keys (``b'data'`` and ``b'labels'``).

    Args:
        batch_path: Path to the pickled batch file.

    Returns:
        A ``(data, labels)`` tuple where ``data`` is a ``float32`` array built
        from ``b'data'`` and ``labels`` is an ``int64`` array built from
        ``b'labels'``.

    Note:
        ``pickle.load`` can execute arbitrary code while loading; only call
        this on trusted files.
    """
    with open(batch_path, 'rb') as handle:
        batch = pickle.load(handle, encoding='bytes')
    return (
        np.array(batch[b'data'], dtype=np.float32),
        np.array(batch[b'labels'], dtype=np.int64),
    )

def read_cifar(path_folder):
    """Load every CIFAR batch file found in a folder and stack them.

    Files whose name starts with ``"data_batch"``, plus the file named exactly
    ``"test_batch"``, are read with ``read_cifar_batch``; every other entry is
    ignored. Filenames are processed in sorted order so the row order of the
    result is reproducible (``os.listdir`` returns entries in arbitrary,
    filesystem-dependent order).

    Args:
        path_folder: Directory containing the batch files.

    Returns:
        A ``(data, labels)`` tuple: ``data`` holds the rows of all matched
        batches stacked along axis 0, and ``labels`` is the 1-D concatenation
        of their labels, aligned with those rows. If no file matches, the
        arrays are empty with shapes ``(0, 3072)`` (``float32``) and ``(0,)``
        (``int64``).
    """
    data_parts = []
    label_parts = []
    for filename in sorted(os.listdir(path_folder)):
        if not (filename.startswith("data_batch") or filename == "test_batch"):
            continue
        batch_path = os.path.join(path_folder, filename)
        batch_data, batch_labels = read_cifar_batch(batch_path)
        data_parts.append(batch_data)
        label_parts.append(batch_labels)

    if not data_parts:
        # Keep the historical empty-result shapes when nothing matched
        # (3072 presumably = 32*32*3 CIFAR pixels per row — TODO confirm).
        return np.empty((0, 3072), dtype=np.float32), np.empty(0, dtype=np.int64)

    # Concatenate once at the end: doing it inside the loop re-copies the
    # accumulated array for every batch.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=None)
    return data, labels

def split_dataset(data, labels, split_factor, *, seed=None):
    """Randomly split samples and their labels into train and test subsets.

    One random permutation is shared by ``data`` and ``labels``, so every
    sample stays paired with its label. The first
    ``int(len(data) * split_factor)`` shuffled samples form the training set;
    the remaining samples form the test set.

    Args:
        data: Array of samples, indexed along its first axis (any rank).
        labels: Array of labels, one per sample in ``data``.
        split_factor: Fraction of samples to put in the training set, in
            ``[0, 1]``.
        seed: Optional seed for a reproducible shuffle. When ``None`` (the
            default) NumPy's global random state is used, as before.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length, or if
            ``split_factor`` is outside ``[0, 1]`` (or NaN).
    """
    num_samples = len(data)
    if len(labels) != num_samples:
        raise ValueError(
            "data and labels must have the same length, "
            f"got {num_samples} and {len(labels)}"
        )
    # Written as `not (...)` so a NaN split_factor is rejected too.
    if not 0 <= split_factor <= 1:
        raise ValueError(f"split_factor must be in [0, 1], got {split_factor}")

    if seed is None:
        # Global RNG: keeps behavior identical for callers using np.random.seed.
        shuffled_indices = np.random.permutation(num_samples)
    else:
        shuffled_indices = np.random.default_rng(seed).permutation(num_samples)
    split_index = int(num_samples * split_factor)

    train_idx = shuffled_indices[:split_index]
    test_idx = shuffled_indices[split_index:]
    # Index only the first axis so inputs of any rank work (the former
    # `data[idx, :]` form required at least 2-D data).
    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]


if __name__ == "__main__":
    # Quick sanity check: load the local CIFAR-10 batches, split them 50/50
    # at random, and print the first ten training labels.
    all_data, all_labels = read_cifar("data/cifar-10-batches-py")
    train_data, train_labels, test_data, test_labels = split_dataset(
        all_data, all_labels, 0.5
    )
    print(train_labels[:10])