diff --git a/read_cifar.py b/read_cifar.py index 5d0736d7dceba6aac6ed351bac7e00697e80b2b6..827588b3cfb53e055ce95cd9b6f2c7bf0afeaaad 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -11,9 +11,19 @@ def read_cifar_batch(path): data = batch[b'data'] labels = batch[b'labels'] return np.float32(data), np.int64(labels) - - + +def read_cifar(folder_path): + data, labels = read_cifar_batch("./data/cifar-10-python.tar/cifar-10-batches-py~/cifar-10-batches-py/test_batch") + for i in range(1,5): + data = np.concatenate((data, read_cifar_batch(folder_path + "/data_batch_" + str(i))[0])) + labels = np.concatenate((labels, read_cifar_batch(folder_path + "/data_batch_" + str(i))[1])) + return data, labels + + + + if __name__ == "__main__": - v1, v2 = read_cifar_batch("./data/cifar-10-batches-py/cifar-10-batches-py~/cifar-10-batches-py/data_batch_1") - print(v1) - print(v2) + data, labels = read_cifar("./data/cifar-10-python.tar/cifar-10-batches-py~/cifar-10-batches-py") + print(data) + print(labels) +