From 9206be845c968d090f2d1a781873b88a3ff21cb1 Mon Sep 17 00:00:00 2001 From: oscarchaufour <101994223+oscarchaufour@users.noreply.github.com> Date: Fri, 20 Oct 2023 17:40:41 +0200 Subject: [PATCH] knn and read cifar knn creation and cread_cifar three functions ok --- knn.py | 7 +++++++ read_cifar.py | 21 ++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 knn.py diff --git a/knn.py b/knn.py new file mode 100644 index 0000000..72ea8fd --- /dev/null +++ b/knn.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Oct 20 17:39:37 2023 + +@author: oscar +""" + diff --git a/read_cifar.py b/read_cifar.py index 4fe8169..cc2befc 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -18,7 +18,6 @@ def read_cifar_batch(batch_path): data = np.array(batch.get(b'data')) labels = np.array(batch.get(b'labels')) - return data, labels def read_cifar (batch_dir) : @@ -43,6 +42,22 @@ def read_cifar (batch_dir) : return data, labels +def split_dataset(data, labels, split) : + + number_total = data.shape[0] + number_train = int(number_total * split) + indices = np.arange(number_total) + np.random.shuffle(indices) + indices_train = indices[:number_train] + indices_test = indices[number_train:] + data_train = data[indices_train] + labels_train = labels[indices_train] + data_test = data[indices_test] + labels_test = labels[indices_test] + + return(data_train, labels_train, data_test, labels_test) + if __name__ == "__main__": - file = "./data/cifar-10-python/data_batch_1" - read_cifar_batch(file) + file = "./data/cifar-10-python/" + data, labels = read_cifar(file) + res = split_dataset(data, labels, 0.8) -- GitLab