import numpy as np

def unpickle(file):
    """Load and return the object stored in the pickle file ``file``.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so 8-bit strings written by Python 2 are returned as ``bytes`` objects.

    NOTE: unpickling can execute arbitrary code; only call this on files
    that come from a trusted source.
    """
    import pickle

    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')


def read_cifar_batch(path):
    """Read one pickled batch file and return its data and labels as arrays.

    Args:
        path: Location of the pickled batch file, passed to ``unpickle``.

    Returns:
        A ``(data, labels)`` tuple: the ``b'data'`` entry cast to ``float32``
        and the ``b'labels'`` entry converted to an ``int64`` array.
    """
    batch = unpickle(path)
    samples = batch[b'data'].astype(np.float32)
    targets = np.array(batch[b'labels']).astype(np.int64)
    return samples, targets

def read_cifar(path):
    """Load the five training batches and the test batch into two arrays.

    Args:
        path: Prefix that the batch file names (``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``) are appended to by plain
            string concatenation, so a directory must include its trailing
            separator.

    Returns:
        A ``(data, labels)`` tuple: the per-batch data stacked row-wise and
        the per-batch labels concatenated, in the order the batches are read
        (training batches 1-5, then the test batch).
    """
    batch_names = [path + "data_batch_" + str(i) for i in range(1, 6)]
    batch_names.append(path + "test_batch")

    data_list = []
    label_list = []
    for batch_name in batch_names:
        batch_data, batch_labels = read_cifar_batch(batch_name)
        data_list.append(batch_data)
        label_list.append(batch_labels)

    # np.vstack replaces np.row_stack, which is deprecated since NumPy 2.0.
    data = np.vstack(data_list)
    labels = np.concatenate(label_list)
    return data, labels
    
def unison_shuffled_copies(a, b):
    """Return shuffled copies of ``a`` and ``b`` that share one permutation.

    Both inputs must have the same length and support indexing with an
    integer index array. The permutation is drawn from NumPy's global random
    state, so seeding ``np.random`` makes the result reproducible.

    Returns:
        A ``(shuffled_a, shuffled_b)`` tuple in which position ``i`` of both
        results comes from the same original index.
    """
    size = len(a)
    assert size == len(b)
    order = np.random.permutation(size)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

def split_dataset(data, labels, split):
    """Shuffle ``data`` and ``labels`` together and cut them into two parts.

    Args:
        data: Samples, indexed along the first axis.
        labels: Labels, same length as ``data``.
        split: Multiplier applied to the number of samples; the cut index is
            ``int(len(data) * split)``, so e.g. ``0.8`` puts the first 80 %
            of the shuffled samples in the first part.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)`` where the
        "train" pieces are everything before the cut index and the "test"
        pieces are everything from the cut index onwards.
    """
    shuffled_data, shuffled_labels = unison_shuffled_copies(data, labels)
    cut = int(len(shuffled_data) * split)
    return (
        shuffled_data[:cut],
        shuffled_labels[:cut],
        shuffled_data[cut:],
        shuffled_labels[cut:],
    )

def split_dataset_n_fold(data, labels, split, N):
    """Shuffle ``data`` and ``labels`` in unison and cut them into ``N`` folds.

    The previous implementation multiplied an integer by the builtin
    ``range`` type, which raises ``TypeError`` for every ``N >= 1``, and it
    overwrote each fold with the next one. This version returns all folds.

    Args:
        data: Samples, indexed along the first axis.
        labels: Labels, same length as ``data``.
        split: Unused; kept so existing call sites keep the same signature.
        N: Number of folds; must be at least 1.

    Returns:
        A ``(data_folds, label_folds)`` tuple of two lists of length ``N``.
        ``data_folds[i]`` and ``label_folds[i]`` are matching contiguous
        slices of the shuffled inputs; fold sizes differ by at most one and
        together the folds cover every sample exactly once.

    Raises:
        ValueError: If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be at least 1, got %r" % (N,))

    shuffled_data, shuffled_labels = unison_shuffled_copies(data, labels)
    n_samples = len(shuffled_data)

    # Integer arithmetic keeps the boundaries exact (no float rounding) and
    # guarantees the last boundary equals n_samples.
    bounds = [(k * n_samples) // N for k in range(N + 1)]

    data_folds = []
    label_folds = []
    for start, stop in zip(bounds, bounds[1:]):
        data_folds.append(shuffled_data[start:stop])
        label_folds.append(shuffled_labels[start:stop])
    return data_folds, label_folds

def read_cifar_preprocess(path):
    """Load all batches under ``path`` and standardize each image channel.

    Args:
        path: Prefix that the batch file names (``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``) are appended to by plain
            string concatenation, so a directory must include its trailing
            separator.

    Returns:
        A ``(data, labels)`` tuple. ``data`` has one flattened image per row
        (``3 * 32 * 32`` values) where every channel has had its mean over
        the whole loaded set subtracted and been divided by its standard
        deviation over the whole loaded set. ``labels`` is the concatenation
        of the per-batch labels in read order.
    """
    batch_names = [path + "data_batch_" + str(i) for i in range(1, 6)]
    batch_names.append(path + "test_batch")

    data_list = []
    label_list = []
    for batch_name in batch_names:
        batch_data, batch_labels = read_cifar_batch(batch_name)
        data_list.append(batch_data)
        label_list.append(batch_labels)

    # np.vstack replaces np.row_stack, which is deprecated since NumPy 2.0.
    data = np.vstack(data_list)
    labels = np.concatenate(label_list)

    num_channels = 3
    image_size = 32
    channel_size = image_size * image_size

    # Interpret each flat row as (channel, 32, 32) so that statistics can be
    # taken per channel.
    data = data.reshape((-1, num_channels, image_size, image_size))

    # One mean / standard deviation per channel, computed over all images
    # and all pixel positions.
    means = np.mean(data, axis=(0, 2, 3))
    stds = np.std(data, axis=(0, 2, 3))

    # Standardize channel by channel; the original annotations label
    # channels 0, 1, 2 as red, green and blue.
    for channel in range(num_channels):
        data[:, channel, :, :] = (
            data[:, channel, :, :] - means[channel]
        ) / stds[channel]

    # Back to one flat row per image.
    data = data.reshape((-1, channel_size * num_channels))

    return data, labels




if __name__ == "__main__":
    # When run as a script, load and standardize the dataset from the default
    # relative directory. The trailing slash matters: the loader appends the
    # batch file names to this string by plain concatenation.
    d,l = read_cifar_preprocess("data/cifar-10-batches-py/")
   