From 9206be845c968d090f2d1a781873b88a3ff21cb1 Mon Sep 17 00:00:00 2001
From: oscarchaufour <101994223+oscarchaufour@users.noreply.github.com>
Date: Fri, 20 Oct 2023 17:40:41 +0200
Subject: [PATCH] knn and read cifar

knn creation and cread_cifar three functions ok
---
 knn.py        |  7 +++++++
 read_cifar.py | 21 ++++++++++++++++++---
 2 files changed, 25 insertions(+), 3 deletions(-)
 create mode 100644 knn.py

diff --git a/knn.py b/knn.py
new file mode 100644
index 0000000..72ea8fd
--- /dev/null
+++ b/knn.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 20 17:39:37 2023
+
+@author: oscar
+"""
+
diff --git a/read_cifar.py b/read_cifar.py
index 4fe8169..cc2befc 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -18,7 +18,6 @@ def read_cifar_batch(batch_path):
         data = np.array(batch.get(b'data'))
         labels = np.array(batch.get(b'labels'))
 
- 
     return data, labels
 
 def read_cifar (batch_dir) :
@@ -43,6 +42,22 @@ def read_cifar (batch_dir) :
 
     return data, labels
 
+def split_dataset(data, labels, split) : 
+    
+    number_total = data.shape[0]
+    number_train = int(number_total * split)
+    indices = np.arange(number_total)
+    np.random.shuffle(indices)
+    indices_train = indices[:number_train]
+    indices_test = indices[number_train:]
+    data_train = data[indices_train]
+    labels_train = labels[indices_train]
+    data_test = data[indices_test]
+    labels_test = labels[indices_test]
+    
+    return(data_train, labels_train, data_test, labels_test)
+
 if __name__ == "__main__":
-    file = "./data/cifar-10-python/data_batch_1"
-    read_cifar_batch(file)
+    file = "./data/cifar-10-python/"
+    data, labels = read_cifar(file)
+    res = split_dataset(data, labels, 0.8)
-- 
GitLab