Skip to content
Snippets Groups Projects
Select Git revision
  • e5b01982650b55893e5009fb4834cc0ec6545db9
  • master default protected
2 results

test_heap.py

Blame
  • Forked from Vuillemot Romain / INF-TC1
    Source project has a limited visibility.
    read_cifar_batch.py 1.23 KiB
    import pickle
    import numpy as np
    import glob
    def unpickle(file):
        """Load and return the object stored in a pickle file.

        The file is opened in binary mode and unpickled with
        ``encoding='bytes'``, so 8-bit strings pickled under Python 2 come
        back as ``bytes`` objects (e.g. dictionary keys such as ``b'data'``).

        Args:
            file: Path of the pickle file to read.

        Returns:
            The unpickled object.

        Warning:
            ``pickle.load`` can execute arbitrary code while loading; only
            call this on files from a trusted source.
        """
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')
    
    def read_cifar_batch(path_to_batch_file):
        """Read one pickled CIFAR-10 batch file into NumPy arrays.

        Args:
            path_to_batch_file: Path to a pickled batch file
                (e.g. ``data_batch_1``).

        Returns:
            A ``(data, labels)`` tuple. ``data`` is a float32 array built from
            the batch's ``b'data'`` entry and divided by 255. ``labels`` is an
            int64 array built from its ``b'labels'`` entry.
        """
        batch = unpickle(path_to_batch_file)
        # Divide by 255 to rescale raw pixel values (presumably 0..255
        # bytes) into floats in [0, 1].
        data = np.array(batch[b'data'], dtype=np.float32) / 255
        labels = np.array(batch[b'labels'], dtype=np.int64)
        return data, labels
    
    def read_cifar(path_to_batches_files):
        """Read and concatenate every CIFAR-10 batch matching a glob pattern.

        Args:
            path_to_batches_files: A ``glob`` pattern selecting batch files,
                e.g. ``".../cifar-10-batches-py/data_batch_*"``.

        Returns:
            A ``(data, labels)`` tuple. It holds the rows of all matching
            batches stacked along axis 0, in sorted file-name order.

        Raises:
            FileNotFoundError: If the pattern matches no file.
        """
        # glob returns matches in arbitrary (filesystem-dependent) order.
        # Sorting makes the sample order reproducible across machines.
        files = sorted(glob.glob(path_to_batches_files))
        if not files:
            raise FileNotFoundError(
                "No CIFAR batch file matches {!r}".format(path_to_batches_files)
            )
        all_data, all_labels = zip(*(read_cifar_batch(f) for f in files))
        # Concatenate once at the end. Growing the arrays inside the loop
        # would recopy every previously loaded row on each iteration.
        data = np.concatenate(all_data, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        return data, labels
    
    
    if __name__ == "__main__":
        path = "image-classification/data/cifar-10-batches-py/data_batch_1"
        data, labels = read_cifar_batch(path)

        # Plot the first 9 images of the batch in a 3x3 grid, each labelled
        # with its integer class index. matplotlib is only needed for this
        # demo, so it is imported here rather than at module level.
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(3, 3)
        fig.subplots_adjust(hspace=0.6, wspace=0.3)
        for i, ax in enumerate(axes.flat):
            # Each row is reshaped channel-first to (3, 32, 32). imshow wants
            # (height, width, channels), hence the transpose.
            ax.imshow(data[i].reshape(3, 32, 32).transpose([1, 2, 0]))
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel(labels[i])
        plt.show()