import os
import pickle
import random as rd

import numpy as np


def read_cifar_batch(path):
    """Load a single CIFAR-10 batch file.

    Example ``path``: ``os.path.join("data", "cifar-10-batches-py", "data_batch_1")``.

    Args:
        path: Path to one pickled CIFAR-10 batch (``data_batch_N`` or ``test_batch``).

    Returns:
        A ``(labels, data)`` tuple where ``labels[i]`` is the label of image ``i``
        and ``data[i]`` holds that image's pixel values (3072 per image in the
        standard CIFAR-10 files).

    Warning:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(path, "rb") as fo:
        # encoding="bytes" is required because the CIFAR-10 files were pickled
        # under Python 2; the dictionary keys therefore come back as bytes.
        batch = pickle.load(fo, encoding="bytes")
    # Look fields up by name instead of by position in the dict, so the result
    # does not depend on the key order stored in the pickle.
    labels = batch[b"labels"]
    data = batch[b"data"]
    return (labels, data)


def read_cifar(path):
    """Load the full CIFAR-10 dataset (test batch followed by the 5 training batches).

    Example ``path``: ``os.path.join("data", "cifar-10-batches-py")``.

    Args:
        path: Directory containing ``test_batch`` and ``data_batch_1`` .. ``data_batch_5``.

    Returns:
        A ``(labels, data)`` tuple: ``labels`` is a list with one label per image,
        ``data`` is the row-wise concatenation of every batch's data array, in the
        order test_batch, data_batch_1, ..., data_batch_5.
    """
    batch_names = ["test_batch"] + [f"data_batch_{i}" for i in range(1, 6)]
    labels = []
    chunks = []
    for name in batch_names:
        # os.path.join keeps this portable (the previous "\\" only worked on Windows),
        # and each file is now read exactly once.
        batch_labels, batch_data = read_cifar_batch(os.path.join(path, name))
        labels.extend(batch_labels)
        chunks.append(batch_data)
    # A single concatenate avoids re-copying the growing array on every batch.
    data = np.concatenate(chunks, axis=0)
    return (labels, data)


def split_dataset(labels, data, split):
    """Randomly split a dataset into a training part and a test part.

    Args:
        labels: Sequence of labels, one per sample.
        data: Array indexable by a list of integer indices (e.g. a NumPy array)
            with one row per sample, aligned with ``labels``.
        split: Fraction of samples (between 0 and 1) to put in the test set; the
            test size is ``round(split * len(labels))``.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``. Training samples keep
        their original order; test samples are in random order.

    Raises:
        ValueError: If ``split`` yields a test size outside ``[0, len(labels)]``.
    """
    n_samples = len(labels)
    n_test = round(split * n_samples)
    if not 0 <= n_test <= n_samples:
        # The previous rejection-sampling loop never terminated in this case.
        raise ValueError(
            f"split={split!r} gives a test size of {n_test}, "
            f"outside [0, {n_samples}]"
        )
    # Draw distinct indices without replacement in one call; the set makes the
    # membership test below O(1) instead of scanning a list for every sample.
    test = rd.sample(range(n_samples), n_test)
    test_set = set(test)
    train = [i for i in range(n_samples) if i not in test_set]

    data_train = data[train]
    data_test = data[test]
    labels_train = [labels[i] for i in train]
    labels_test = [labels[i] for i in test]
    return (data_train, labels_train, data_test, labels_test)


if __name__ == "__main__":
    path = os.path.join("data", "cifar-10-batches-py")
    labels, data = read_cifar(path)
    res = split_dataset(labels, data, 0.1)