import glob
import numpy as np
import pickle

def read_cifar_batch(batch_path):
    """Load a single pickled CIFAR batch file.

    Args:
        batch_path: Path to a pickle file whose top-level object is indexed
            with the byte-string keys ``b'data'`` and ``b'labels'``.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is a float32 array built from
        ``b'data'`` and divided by 255.0. Presumably this maps 8-bit pixel
        values into [0, 1]; confirm against the source files. ``labels`` is
        an int64 array built from ``b'labels'``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at batch files obtained from a trusted source.
    with open(batch_path, 'rb') as handle:
        # encoding='bytes' keeps Python-2-era str objects as bytes, which is
        # why the keys below are looked up as b'...' literals.
        raw = pickle.load(handle, encoding='bytes')

    pixels = np.array(raw[b'data'], dtype=np.float32)
    pixels = pixels / 255.0
    targets = np.array(raw[b'labels'], dtype=np.int64)
    return pixels, targets

def read_cifar(directory):
    """Load and concatenate every CIFAR batch file found in ``directory``.

    Files are selected with the glob pattern ``'<directory>/*_batch*'`` and
    read with ``read_cifar_batch``. Their data rows and labels are stacked in
    sorted filename order, so the result is deterministic across filesystems.

    NOTE(review): in the usual CIFAR-10 layout this pattern matches both the
    ``data_batch_*`` files and ``test_batch``, so training and test samples
    would be merged into one array. Confirm this is intended.

    Args:
        directory: Directory containing the pickled batch files.

    Returns:
        A ``(data, labels)`` tuple. If no file matches, ``data`` is an empty
        float32 array of shape ``(0, 3072)`` and ``labels`` is an empty int64
        array of shape ``(0,)``.
    """
    # glob.glob's ordering depends on the filesystem; sort so that sample
    # order (and thus any downstream split/shuffle seed) is reproducible.
    files = sorted(glob.glob(f'{directory}/*_batch*'))

    # Seed with empty arrays so the no-files case keeps the original shapes
    # and dtypes, then stack once at the end. Stacking inside the loop would
    # re-copy the growing array on every iteration (quadratic copying).
    data_parts = [np.empty((0, 3072), dtype=np.float32)]
    label_parts = [np.empty((0,), dtype=np.int64)]
    for file in files:
        batch_data, batch_labels = read_cifar_batch(file)
        data_parts.append(batch_data)
        label_parts.append(batch_labels)

    data = np.vstack(data_parts)
    labels = np.hstack(label_parts)
    return data, labels