diff --git a/read_cifar.py b/read_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0736d7dceba6aac6ed351bac7e00697e80b2b6 --- /dev/null +++ b/read_cifar.py @@ -0,0 +1,19 @@ +import pickle +import numpy as np + +def unpickle(file): + with open(file, 'rb') as fo: + batch = pickle.load(fo, encoding='bytes') + return batch + +def read_cifar_batch(path): + batch = unpickle(path) + data = batch[b'data'] + labels = batch[b'labels'] + return np.float32(data), np.int64(labels) + + +if __name__ == "__main__": + v1, v2 = read_cifar_batch("./data/cifar-10-batches-py/cifar-10-batches-py~/cifar-10-batches-py/data_batch_1") + print(v1) + print(v2)