import glob
import pickle

import numpy as np


def read_cifar_batch(batch_path):
    """Load one pickled batch file and return ``(data, labels)``.

    Args:
        batch_path: Path to a pickle file holding a dict with the byte-string
            keys ``b'data'`` and ``b'labels'``.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is a ``float32`` array of the
        stored values divided by 255.0, ``labels`` is an ``int64`` array.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the unpickled dict lacks ``b'data'`` or ``b'labels'``.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(batch_path, 'rb') as f:
        # encoding='bytes' keeps the dict keys as byte strings (b'data', ...).
        batch = pickle.load(f, encoding='bytes')
    data = np.array(batch[b'data'], dtype=np.float32) / 255.0
    labels = np.array(batch[b'labels'], dtype=np.int64)
    return data, labels


def read_cifar(directory):
    """Load every batch file in ``directory`` and concatenate the results.

    Files are selected with the glob pattern ``<directory>/*_batch*`` and read
    in sorted path order, so the row order of the result is reproducible
    regardless of the order in which the filesystem lists the files.

    Args:
        directory: Directory containing the pickled batch files.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is a ``float32`` array with 3072
        columns and ``labels`` is a 1-D ``int64`` array. When no file matches,
        the arrays are empty with shapes ``(0, 3072)`` and ``(0,)``.
    """
    # glob.glob returns paths in arbitrary order; sort for determinism.
    files = sorted(glob.glob(f'{directory}/*_batch*'))

    # Seed with empty arrays so an empty directory still yields well-shaped
    # results and any batch that is not 3072 columns wide is rejected.
    data_parts = [np.empty((0, 3072), dtype=np.float32)]
    label_parts = [np.empty((0), dtype=np.int64)]
    for file in files:
        batch_data, batch_labels = read_cifar_batch(file)
        data_parts.append(batch_data)
        label_parts.append(batch_labels)

    # Stack once at the end rather than re-copying the accumulated array on
    # every iteration.
    data = np.vstack(data_parts)
    labels = np.hstack(label_parts)
    return data, labels