diff --git a/read_cifar.py b/read_cifar.py index abba3fae17d1007e47e2551059dd4adf24b8e528..2aeb021bed6dfd1767498bd1196c82e0eef67a2b 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -7,8 +7,29 @@ def read_cifar_batch(file): labels = np.array(dict[b'labels']).astype('int64') return data, labels -vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1") +#vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1") #print(vect1) -#def read_cifar \ No newline at end of file +def read_cifar(directory): + all_data = [] + all_labels = [] + + for i in range(1,6): + data_v, labels_v = read_cifar_batch(f'{directory}/data_batch_{i}') + all_data.append(data_v) + all_labels.append(labels_v) + + data_v, labels_v = read_cifar_batch(f'{directory}/test_batch') + all_data.append(data_v) + all_labels.append(labels_v) + + all_data = np.concatenate(all_data, axis = 0) + all_labels = np.concatenate(all_labels, axis = 0) + + return(all_data, all_labels) + + +#vect2= read_cifar("data/cifar-10-batches-py") +#print(vect2) +