def split_dataset(data, labels, split):
    """Randomly partition a dataset into a training part and a test part.

    The samples are shuffled with NumPy's global random state (seed it with
    ``np.random.seed`` for a reproducible split), then the first
    ``int(n_samples * split)`` shuffled samples form the training set and the
    remaining ones form the test set. ``data`` and ``labels`` are indexed
    with the same permutation, so sample/label pairs stay aligned.

    Parameters
    ----------
    data : array-like
        Samples, indexed along the first axis.
    labels : array-like
        One label per sample; must have the same length as ``data``.
    split : float
        Fraction of the samples assigned to the training set, in ``[0, 1]``.

    Returns
    -------
    tuple
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside ``[0, 1]`` or if ``data`` and ``labels`` do
        not contain the same number of samples.
    """
    # A negative fraction would make the slices below count from the end,
    # and a fraction above 1 would silently produce an empty test set.
    if not 0 <= split <= 1:
        raise ValueError(f"split must be between 0 and 1, got {split!r}")

    data = np.asarray(data)
    labels = np.asarray(labels)

    number_total = data.shape[0]
    if labels.shape[0] != number_total:
        raise ValueError(
            "data and labels must contain the same number of samples, "
            f"got {number_total} and {labels.shape[0]}"
        )

    number_train = int(number_total * split)

    # One shared permutation keeps every sample paired with its label.
    indices = np.random.permutation(number_total)
    indices_train = indices[:number_train]
    indices_test = indices[number_train:]

    return (
        data[indices_train],
        labels[indices_train],
        data[indices_test],
        labels[indices_test],
    )