import pickle
import numpy as np
import glob
def unpickle(file):
    """Load and return the object stored in the pickle file at *file*.

    The file is opened in binary mode and unpickled with ``encoding='bytes'``,
    so any Python 2 ``str`` objects in the pickle are decoded as ``bytes``
    (hence keys like ``b'data'`` in the CIFAR batch dictionaries).

    SECURITY: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted files, such as the official CIFAR-10 download.
    """
    with open(file, 'rb') as fo:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch

def read_cifar_batch(path_to_batch_file):
    """Read one CIFAR batch file and return ``(data, labels)`` as NumPy arrays.

    Args:
        path_to_batch_file: path to a pickled batch whose dictionary has
            ``b'data'`` and ``b'labels'`` entries.

    Returns:
        data: ``float32`` array of ``b'data'`` scaled by 1/255. For raw 8-bit
            pixel values this gives the range [0, 1]. Presumably the shape is
            (N, 3072) in the standard CIFAR-10 format; confirm against the
            dataset in use.
        labels: ``int64`` array built from ``b'labels'``.
    """
    batch = unpickle(path_to_batch_file)
    # float32 array divided by a Python int stays float32.
    data = np.array(batch[b'data'], dtype=np.float32) / 255
    labels = np.array(batch[b'labels'], dtype=np.int64)
    return data, labels

def read_cifar(path_to_batches_files):
    """Read every batch file matching a glob pattern and stack them.

    Args:
        path_to_batches_files: glob pattern, e.g.
            ``".../cifar-10-batches-py/*_batch*"``.

    Returns:
        ``(data, labels)``: the per-batch arrays from ``read_cifar_batch``,
        concatenated along axis 0 in sorted filename order.

    Raises:
        FileNotFoundError: if the pattern matches no files.
    """
    # glob.glob returns matches in arbitrary, OS-dependent order. Sort them
    # so the sample order (and the data/label pairing across runs) is
    # deterministic.
    files = sorted(glob.glob(path_to_batches_files))
    if not files:
        raise FileNotFoundError(
            f"No CIFAR batch files match {path_to_batches_files!r}"
        )
    data_parts, label_parts = zip(*(read_cifar_batch(f) for f in files))
    # Concatenate once, rather than repeatedly in a loop, to avoid
    # re-copying the growing arrays on every iteration.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    return data, labels


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    path = "image-classification/data/cifar-10-batches-py/data_batch_1"
    # Load the batch once; the previous version loaded it twice and
    # discarded the first result.
    data, labels = read_cifar_batch(path)

    # Plot the first 9 images of the batch in a 3x3 grid.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.6, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # Reshape the flat row to channel-first (3, 32, 32), then transpose
        # to (32, 32, 3) for imshow.
        ax.imshow(data[i].reshape(3, 32, 32).transpose([1, 2, 0]))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(labels[i])
    plt.show()