From 55ce9e8d635adf084c111572f64964da4c7ca2ac Mon Sep 17 00:00:00 2001 From: Delorme Antonin <antonin.delorme@etu.ec-lyon.fr> Date: Fri, 10 Nov 2023 19:03:46 +0000 Subject: [PATCH] Upload New File --- read_cifar.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 read_cifar.py diff --git a/read_cifar.py b/read_cifar.py new file mode 100644 index 0000000..602ac3e --- /dev/null +++ b/read_cifar.py @@ -0,0 +1,54 @@ +import pickle +import numpy as np +import random as rd + + +def read_cifar_batch(path): + """ path = "data\cifar-10-batches-py\data_batch_1" + par exemple """ + with open(path, 'rb') as fo: + dict = pickle.load(fo, encoding='bytes') + + labels=list(dict.items())[1][1] #labels[i] est le label de l'ième image + data=list(dict.items())[2][1] #data[i] sont les 3072 pixel de l'image i + return (labels,data) + +def read_cifar(path): + """ path="data\cifar-10-batches-py" par exemple """ + (labels,data)=read_cifar_batch(path+"\\test_batch") + + for i in range(1,6): + data=np.concatenate((data,read_cifar_batch(path+"\\data_batch_"+str(i))[1]),axis=0) + labels=labels+read_cifar_batch(path+"\\data_batch_"+str(i))[0] + return (labels,data) + +def split_dataset(labels,data,split): + split=round(split*len(labels)) + test=[] + while len(test) != split: + Nb=rd.randint(0,len(labels)-1) + if Nb not in test : + test.append(Nb) + train=[i for i in range(len(labels)) if i not in test] + + data_train=data[train] + data_test=data[test] + labels_test=[] + labels_train=[] + for i in test: + labels_test.append(labels[i]) + for j in train: + labels_train.append(labels[j]) + + return(data_train,labels_train,data_test,labels_test) + + + +if __name__ == "__main__": + #path="data\\cifar-10-batches-py\\test_batch" + #data=read_cifar_batch(path) + path="data\\cifar-10-batches-py" + labels,data=read_cifar(path) + res=split_dataset(labels,data,0.1) + + -- GitLab