diff --git a/read_cifar.py b/read_cifar.py index da73e4ba56286ea2627a5b5c7655eaa2dd9df31d..c8cedb06042955a3dd45553a657c7b859909e020 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -1,3 +1,15 @@ import numpy as np -print("Test") \ No newline at end of file +import pickle + +def read_cifar_batch(batch_path): + + with open(batch_path, "rb") as f: + batch = pickle.load(f, encoding="bytes") + + data = np.array(batch[b'data'], dtype=np.float32) + labels = np.array(batch[b'labels'], dtype=np.int64) + + return data, labels + +print(read_cifar_batch('Data/cifar-10-batches-py/data_batch_2')) \ No newline at end of file