import glob
import pickle

import numpy as np


def unpickle(file):
    """Load and return the object stored in the pickle file at *file*.

    The file is unpickled with ``encoding='bytes'``, so Python 2 byte strings
    inside it are returned as ``bytes`` rather than decoded to ``str``; the
    readers below accordingly index the result with ``bytes`` keys.

    Security note: ``pickle.load`` can execute arbitrary code while loading.
    Only call this on files from a trusted source.
    """
    with open(file, "rb") as fo:
        return pickle.load(fo, encoding="bytes")


def read_cifar_batch(path_to_batch_file):
    """Read a single pickled CIFAR batch file.

    Args:
        path_to_batch_file: Path to a pickled dict holding the image rows
            under ``b'data'`` and the class ids under ``b'labels'``.

    Returns:
        A ``(data, labels)`` tuple:
        - ``data``: float32 array of the stored pixel values divided by 255
          (assumes 8-bit pixels, so the result lies in [0, 1] — TODO confirm).
        - ``labels``: int64 array of the class labels.
    """
    batch = unpickle(path_to_batch_file)  # avoid shadowing the builtin `dict`
    data = np.array(batch[b"data"], dtype=np.float32) / 255
    labels = np.array(batch[b"labels"], dtype=np.int64)
    return data, labels


def read_cifar(path_to_batches_files):
    """Read every batch file matching a glob pattern and stack them.

    Args:
        path_to_batches_files: A ``glob`` pattern, e.g.
            ``"cifar-10-batches-py/data_batch_*"``.

    Returns:
        A ``(data, labels)`` tuple with all matched batches concatenated
        along axis 0, in sorted filename order.

    Raises:
        FileNotFoundError: If the pattern matches no file.
    """
    # glob.glob returns matches in arbitrary (filesystem) order; sort them so
    # the resulting sample order is reproducible across machines.
    files = sorted(glob.glob(path_to_batches_files))
    if not files:
        raise FileNotFoundError(
            f"No CIFAR batch file matches {path_to_batches_files!r}"
        )

    data_parts, label_parts = zip(*(read_cifar_batch(f) for f in files))
    # One concatenate at the end instead of re-copying the growing arrays
    # once per batch.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    return data, labels


if __name__ == "__main__":
    path = "image-classification/data/cifar-10-batches-py/data_batch_1"
    data, labels = read_cifar_batch(path)

    # Plot the first 9 images of the batch; matplotlib is only needed here.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.6, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # Rows are reshaped as (channels, height, width) and transposed to
        # (height, width, channels), the layout imshow expects.
        ax.imshow(data[i].reshape(3, 32, 32).transpose([1, 2, 0]))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(labels[i])
    plt.show()