import os
import pickle
import random as rd

import numpy as np


def read_cifar_batch(path):
    """ path = "data\cifar-10-batches-py\data_batch_1" 
    par exemple """
    with open(path, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    
    labels=list(dict.items())[1][1] #labels[i] est le label de l'ième image
    data=list(dict.items())[2][1] #data[i] sont les 3072 pixel de l'image i
    return (labels,data)

def read_cifar(path):
    """ path="data\cifar-10-batches-py" par exemple """
    (labels,data)=read_cifar_batch(path+"\\test_batch")
    
    for i in range(1,6):
        data=np.concatenate((data,read_cifar_batch(path+"\\data_batch_"+str(i))[1]),axis=0)
        labels=labels+read_cifar_batch(path+"\\data_batch_"+str(i))[0]
    return (labels,data)

def split_dataset(labels,data,split):
    split=round(split*len(labels))
    test=[]
    while len(test) != split:
        Nb=rd.randint(0,len(labels)-1)
        if Nb not in test :
            test.append(Nb)
    train=[i for i in range(len(labels)) if i not in test]
    
    data_train=data[train]
    data_test=data[test]
    labels_test=[]
    labels_train=[]
    for i in test:
        labels_test.append(labels[i])
    for j in train:
        labels_train.append(labels[j])
        
    return(data_train,labels_train,data_test,labels_test)
            

    
if __name__ == "__main__":
    #path="data\\cifar-10-batches-py\\test_batch"
    #data=read_cifar_batch(path)
    path="data\\cifar-10-batches-py"
    labels,data=read_cifar(path)
    res=split_dataset(labels,data,0.1)