"""import numpy""" import numpy as np import pickle import os def read_cifar_batch(batch_path): """F""" with open(batch_path, 'rb') as file: batch_data = pickle.load(file, encoding='bytes') data = np.array(batch_data[b'data'], dtype=np.float32) labels = np.array(batch_data[b'labels'], dtype=np.int64) return data, labels def read_cifar(path_folder): data = np.empty((0, 3072), dtype=np.float32) labels = np.empty((0), dtype=np.int64) for filename in os.listdir(path_folder): if filename.startswith("data_batch") or filename == "test_batch": batch_path = os.path.join(path_folder, filename) d, l = read_cifar_batch(batch_path) data = np.concatenate((data, d), axis=0) labels = np.concatenate((labels, l), axis=None) return(data,labels) def split_dataset(data, labels, split_factor): """fonction""" num_samples = len(data) shuffled_indices = np.random.permutation(num_samples) split_index = int(num_samples * split_factor) data_train = data[shuffled_indices[:split_index],:] labels_train = labels[shuffled_indices[:split_index]] data_test = data[shuffled_indices[split_index:],:] labels_test = labels[shuffled_indices[split_index:]] return data_train, labels_train, data_test, labels_test if __name__ == "__main__": #read_cifar_batch("data/cifar-10-batches-py/data_batch_1") d, l = read_cifar("data/cifar-10-batches-py") d_1, l_1, d_2, l_2 = split_dataset(d, l, 0.5) print(l_1[0:10])