"""Helpers for loading the CIFAR-10 "python version" batches and splitting them."""

import os
import pickle
import random as rd

import numpy as np


def read_cifar_batch(path):
    """Load a single CIFAR-10 batch file.

    Args:
        path: Path to one pickled batch file, for example
            ``data/cifar-10-batches-py/data_batch_1``.

    Returns:
        A ``(labels, data)`` tuple read from the ``b"labels"`` and ``b"data"``
        entries of the pickled dictionary. ``labels[i]`` is the label of
        image ``i`` and ``data[i]`` holds the pixels of image ``i``.

    Raises:
        KeyError: If the pickled dictionary has no ``b"labels"`` or
            ``b"data"`` entry.
    """
    # NOTE: pickle can execute arbitrary code while loading; only open batch
    # files that come from a trusted source.
    with open(path, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')

    # Look the fields up by key rather than by position so the result does not
    # depend on the order in which the dictionary was written.
    labels = batch[b'labels']
    data = batch[b'data']
    return (labels, data)


def read_cifar(path):
    """Load every batch found in a CIFAR-10 directory.

    The batches are read in the order ``test_batch``, ``data_batch_1`` ...
    ``data_batch_5`` and stacked in that same order.

    Args:
        path: Directory containing the batch files, for example
            ``data/cifar-10-batches-py``.

    Returns:
        A ``(labels, data)`` tuple where ``labels`` is one list with the labels
        of all batches and ``data`` is the concatenation of all batch arrays
        along axis 0.
    """
    batch_names = ["test_batch"] + ["data_batch_" + str(i) for i in range(1, 6)]

    labels = []
    chunks = []
    for name in batch_names:
        # os.path.join keeps the loader usable on every platform, and each
        # file is unpickled only once.
        batch_labels, batch_data = read_cifar_batch(os.path.join(path, name))
        labels.extend(batch_labels)
        chunks.append(batch_data)

    data = np.concatenate(chunks, axis=0)
    return (labels, data)


def split_dataset(labels, data, split):
    """Randomly split a dataset into a training part and a test part.

    Args:
        labels: Sequence of labels, one per sample.
        data: Array indexable by a list of row indices (``data[i]`` is the
            sample that goes with ``labels[i]``).
        split: Fraction of the samples, between 0 and 1, placed in the test
            part. The test size is ``round(split * len(labels))``.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``. Training
        samples keep their original relative order; test samples are in the
        random order in which they were drawn.

    Raises:
        ValueError: If ``split`` is not within ``[0, 1]``.
    """
    if not 0 <= split <= 1:
        raise ValueError("split must be between 0 and 1, got %r" % (split,))

    n_samples = len(labels)
    n_test = round(split * n_samples)

    # Draw distinct indices in one call instead of retrying on duplicates.
    test = rd.sample(range(n_samples), n_test)
    held_out = set(test)  # O(1) membership checks while building the train set
    train = [i for i in range(n_samples) if i not in held_out]

    # An explicit integer dtype keeps indexing well defined for empty lists.
    data_train = data[np.asarray(train, dtype=np.intp)]
    data_test = data[np.asarray(test, dtype=np.intp)]
    labels_train = [labels[j] for j in train]
    labels_test = [labels[i] for i in test]

    return (data_train, labels_train, data_test, labels_test)


if __name__ == "__main__":
    path = os.path.join("data", "cifar-10-batches-py")
    labels, data = read_cifar(path)
    res = split_dataset(labels, data, 0.1)