# -*- coding: utf-8 -*-
"""
Created on Fri Oct 20 16:04:49 2023

@author: oscar

Helpers to load pickled CIFAR batch files from a directory and to randomly
split the resulting arrays into a training part and a test part.
"""
import os
import pickle

import numpy as np


def read_cifar_batch(batch_path):
    """Load one pickled batch file and return its data and labels.

    Parameters
    ----------
    batch_path : str
        Path of the pickled batch file.

    Returns
    -------
    data : numpy.ndarray
        Array built from the ``b'data'`` entry of the unpickled dictionary.
    labels : numpy.ndarray
        Array built from the ``b'labels'`` entry of the unpickled dictionary.

    Notes
    -----
    A missing key is not an error here: ``dict.get`` returns ``None`` and the
    corresponding array is then ``np.array(None)``.
    """
    with open(batch_path, 'rb') as f:
        # SECURITY: unpickling can execute arbitrary code; only use this on
        # batch files coming from a trusted source.
        # encoding='bytes' keeps the dictionary keys as bytes (b'data', ...).
        batch = pickle.load(f, encoding='bytes')
    data = np.array(batch.get(b'data'))
    labels = np.array(batch.get(b'labels'))
    return data, labels


def read_cifar(batch_dir):
    """Load every batch found in ``batch_dir`` as a single dataset.

    The files ``data_batch_1`` ... ``data_batch_5`` are read first, followed
    by ``test_batch``; all of them are concatenated along the first axis, so
    the samples of the test batch are part of the returned arrays.

    Parameters
    ----------
    batch_dir : str
        Directory containing the six batch files.

    Returns
    -------
    data : numpy.ndarray
        Concatenation of the data arrays of the six batches.
    labels : numpy.ndarray
        Concatenation of the label arrays, in the same order as ``data``.
    """
    batch_filenames = [f'data_batch_{i}' for i in range(1, 6)]
    # The test batch is deliberately appended last: the caller re-splits the
    # merged dataset afterwards (see split_dataset).
    batch_filenames.append('test_batch')

    data_batches = []
    label_batches = []
    for batch_filename in batch_filenames:
        batch_path = os.path.join(batch_dir, batch_filename)
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)

    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels


def split_dataset(data, labels, split):
    """Randomly split a dataset into a training part and a test part.

    The samples are shuffled with ``np.random.shuffle`` (global NumPy random
    state), so call ``np.random.seed`` beforehand for a reproducible split.

    Parameters
    ----------
    data : numpy.ndarray
        Samples, indexed along the first axis.
    labels : numpy.ndarray
        Labels, one per sample, in the same order as ``data``.
    split : float
        Fraction of the samples, between 0 and 1, assigned to the training
        part. The training size is ``int(number_total * split)``.

    Returns
    -------
    tuple
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises
    ------
    ValueError
        If ``split`` is outside ``[0, 1]`` or if ``data`` and ``labels`` do
        not contain the same number of samples.
    """
    if not 0 <= split <= 1:
        raise ValueError(f"split must be between 0 and 1, got {split!r}")

    number_total = data.shape[0]
    if len(labels) != number_total:
        raise ValueError(
            f"data and labels must have the same number of samples, "
            f"got {number_total} and {len(labels)}"
        )

    number_train = int(number_total * split)

    # Shuffle indices rather than the arrays themselves so that data and
    # labels are permuted identically and the inputs are left untouched.
    indices = np.arange(number_total)
    np.random.shuffle(indices)
    indices_train = indices[:number_train]
    indices_test = indices[number_train:]

    data_train = data[indices_train]
    labels_train = labels[indices_train]
    data_test = data[indices_test]
    labels_test = labels[indices_test]
    return data_train, labels_train, data_test, labels_test


if __name__ == "__main__":
    cifar_dir = "./data/cifar-10-python/"
    data, labels = read_cifar(cifar_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(
        data, labels, 0.8
    )