import numpy as np
import pickle
def read_cifar_batch(file):
    """Read one pickled CIFAR-10 python batch file into numpy arrays.

    Args:
        file: Path to a single batch file (for example ``data_batch_1`` or
            ``test_batch``).

    Returns:
        tuple: ``(data, labels)``.
        ``data`` is a ``float32`` array built from the batch's ``b'data'`` entry.
        ``labels`` is an ``int64`` array built from its ``b'labels'`` entry.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        call this on batch files obtained from a trusted source.
    """
    with open(file, 'rb') as fo:
        # encoding='bytes' makes pickle decode Python 2 8-bit strings as
        # ``bytes``, which is why the keys below are looked up as b'...'.
        batch = pickle.load(fo, encoding='bytes')
    # The file is no longer needed once unpickled, so convert after closing it.
    # asarray with an explicit dtype converts in a single step instead of
    # copying first (np.array) and then copying again (astype).
    data = np.asarray(batch[b'data'], dtype=np.float32)
    labels = np.asarray(batch[b'labels'], dtype=np.int64)
    return data, labels

#vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")

#print(vect1)

def read_cifar(directory):
    """Load all CIFAR-10 python batches found in *directory* as one dataset.

    Reads ``data_batch_1`` through ``data_batch_5`` in order, then
    ``test_batch``, using ``read_cifar_batch`` for each file. It stacks the
    per-batch arrays along axis 0, so test rows follow the training rows.

    Args:
        directory: Folder containing the extracted CIFAR-10 batch files.

    Returns:
        tuple: ``(data, labels)`` holding the concatenated arrays.
    """
    # Training batches first (1..5), then the test batch: this fixes row order.
    batch_files = [f'{directory}/data_batch_{i}' for i in range(1, 6)]
    batch_files.append(f'{directory}/test_batch')

    loaded = [read_cifar_batch(path) for path in batch_files]
    data_parts, label_parts = zip(*loaded)

    stacked_data = np.concatenate(data_parts, axis=0)
    stacked_labels = np.concatenate(label_parts, axis=0)
    return stacked_data, stacked_labels


#vect2= read_cifar("data/cifar-10-batches-py")
#print(vect2)