From 55ce9e8d635adf084c111572f64964da4c7ca2ac Mon Sep 17 00:00:00 2001
From: Delorme Antonin <antonin.delorme@etu.ec-lyon.fr>
Date: Fri, 10 Nov 2023 19:03:46 +0000
Subject: [PATCH] Upload New File

---
 read_cifar.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 read_cifar.py

diff --git a/read_cifar.py b/read_cifar.py
new file mode 100644
index 0000000..602ac3e
--- /dev/null
+++ b/read_cifar.py
@@ -0,0 +1,54 @@
+import pickle
+import numpy as np
+import random as rd
+
+
+def read_cifar_batch(path):
+    """Read one pickled CIFAR batch file and return ``(labels, data)``.
+
+    ``labels[i]`` is the label of image i and ``data[i]`` holds its pixel values,
+    e.g. for path = "data/cifar-10-batches-py/data_batch_1".
+    """
+    with open(path, 'rb') as fo:
+        batch = pickle.load(fo, encoding='bytes')  # NOTE: only unpickle trusted files
+    return (batch[b'labels'], batch[b'data'])  # look up by key, not by dict position
+
+def read_cifar(path):
+    """Load test_batch then data_batch_1..5 from directory *path*; return (labels, data)."""
+    import os  # local import: os.path.join replaces the Windows-only "\\" separator
+    (labels, data) = read_cifar_batch(os.path.join(path, "test_batch"))
+    for i in range(1, 6):
+        (more_labels, more_data) = read_cifar_batch(os.path.join(path, "data_batch_" + str(i)))  # one read per file
+        labels, data = labels + more_labels, np.concatenate((data, more_data), axis=0)
+    return (labels, data)
+
+def split_dataset(labels,data,split):
+    """Randomly split the dataset into train and test subsets.
+
+    Args:
+        labels: sequence of labels, one per row of ``data``.
+        data: array that accepts a list of row indices (e.g. a NumPy array).
+        split: fraction of the samples placed in the TEST subset.
+
+    Returns:
+        ``(data_train, labels_train, data_test, labels_test)``; train rows keep
+        their original order, test rows follow the random draw order.
+
+    Raises:
+        ValueError: if the rounded test size is negative or exceeds ``len(labels)``.
+    """
+    test = rd.sample(range(len(labels)), round(split * len(labels)))  # distinct indices
+    held_out = set(test)  # O(1) membership; the list scan made the split quadratic
+    train = [i for i in range(len(labels)) if i not in held_out]
+    return (data[train], [labels[j] for j in train], data[test], [labels[i] for i in test])
+            
+
+    
+if __name__ == "__main__":
+    # Smoke run: load all CIFAR-10 batches, then hold out 10% of the samples for testing.
+    import os  # local import; os.path.join picks the platform's path separator
+    path = os.path.join("data", "cifar-10-batches-py")
+    labels, data = read_cifar(path)
+    res = split_dataset(labels, data, 0.1)
+    
+    
-- 
GitLab