import numpy as np
import pickle


def read_cifar_batch(file):
    """Load a single pickled CIFAR batch file.

    Args:
        file: Path to one batch file (e.g. ``data_batch_1`` or ``test_batch``).

    Returns:
        A ``(data, labels)`` tuple:

        - ``data`` is a float32 array built from the batch's ``b'data'`` entry.
        - ``labels`` is an int64 array built from its ``b'labels'`` entry.

        For the standard CIFAR-10 python release ``data`` is presumably
        ``(N, 3072)``; this function does not check the shape.

    Warning:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        call this on batch files obtained from a trusted source.
    """
    with open(file, 'rb') as fo:
        # encoding='bytes' keeps the dict keys as bytes, hence b'data' / b'labels'.
        batch = pickle.load(fo, encoding='bytes')
    # np.asarray with dtype converts in one step; np.array(...).astype(...) copied twice.
    data = np.asarray(batch[b'data'], dtype=np.float32)
    labels = np.asarray(batch[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(directory):
    """Load every CIFAR batch in *directory* and stack them into one dataset.

    Reads ``data_batch_1`` through ``data_batch_5`` followed by ``test_batch``.
    The batches are concatenated along axis 0, so the training rows come first
    and the test rows come last.

    Args:
        directory: Folder containing the batch files.

    Returns:
        A ``(all_data, all_labels)`` tuple holding the concatenated float32
        data and int64 labels. Row ``i`` of ``all_data`` corresponds to
        ``all_labels[i]``.
    """
    # The order matters: callers may rely on the test rows being last.
    names = [f'data_batch_{i}' for i in range(1, 6)] + ['test_batch']
    batches = [read_cifar_batch(f'{directory}/{name}') for name in names]
    all_data = np.concatenate([data for data, _ in batches], axis=0)
    all_labels = np.concatenate([labels for _, labels in batches], axis=0)
    return all_data, all_labels