Skip to content
Snippets Groups Projects
Commit 55ce9e8d authored by Delorme Antonin's avatar Delorme Antonin
Browse files

Upload New File

parent e1fb51cb
No related branches found
No related tags found
No related merge requests found
import pickle
import numpy as np
import random as rd
def read_cifar_batch(path):
""" path = "data\cifar-10-batches-py\data_batch_1"
par exemple """
with open(path, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
labels=list(dict.items())[1][1] #labels[i] est le label de l'ième image
data=list(dict.items())[2][1] #data[i] sont les 3072 pixel de l'image i
return (labels,data)
def read_cifar(path):
""" path="data\cifar-10-batches-py" par exemple """
(labels,data)=read_cifar_batch(path+"\\test_batch")
for i in range(1,6):
data=np.concatenate((data,read_cifar_batch(path+"\\data_batch_"+str(i))[1]),axis=0)
labels=labels+read_cifar_batch(path+"\\data_batch_"+str(i))[0]
return (labels,data)
def split_dataset(labels,data,split):
split=round(split*len(labels))
test=[]
while len(test) != split:
Nb=rd.randint(0,len(labels)-1)
if Nb not in test :
test.append(Nb)
train=[i for i in range(len(labels)) if i not in test]
data_train=data[train]
data_test=data[test]
labels_test=[]
labels_train=[]
for i in test:
labels_test.append(labels[i])
for j in train:
labels_train.append(labels[j])
return(data_train,labels_train,data_test,labels_test)
if __name__ == "__main__":
#path="data\\cifar-10-batches-py\\test_batch"
#data=read_cifar_batch(path)
path="data\\cifar-10-batches-py"
labels,data=read_cifar(path)
res=split_dataset(labels,data,0.1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment