Wrong labels_train shape when splitting

When I run

from read_cifar import read_cifar, split_dataset

data, labels = read_cifar("data/cifar-10-batches-py")
data_train, labels_train, data_test, labels_test = split_dataset(
    data, labels, split=0.9
)
assert data_train.shape == (54000, 3072)
assert labels_train.shape == (54000,)

I get

Traceback (most recent call last):
  File "/home/qgallouedec/image-classification/test.py", line 8, in <module>
    assert labels_train.shape == (54000,)
AssertionError
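The report does not show the cause. CIFAR-10 has 60000 images, so split=0.9 should give 54000 training samples. The data assertion passes, so only the labels come back with the wrong shape. One thing worth checking is whether `labels` comes back as a 2-D column `(N, 1)` or as a plain list instead of a 1-D array. The following is a hypothetical sketch of a `split_dataset` that guarantees 1-D label arrays. It assumes `data` is an `(N, 3072)` NumPy array and `labels` has N entries. It is not the repository's actual implementation.

```python
import numpy as np


def split_dataset(data, labels, split):
    """Hypothetical sketch: shuffle (data, labels) together and put a
    fraction `split` of the samples in the training set."""
    data = np.asarray(data)
    # Assumption: labels may arrive as a list or an (N, 1) column;
    # reshape(-1) flattens either form to shape (N,).
    labels = np.asarray(labels).reshape(-1)
    if data.shape[0] != labels.shape[0]:
        raise ValueError("data and labels must have the same number of samples")

    n_samples = data.shape[0]
    n_train = int(n_samples * split)

    # Use one permutation for both arrays so each image keeps its label.
    perm = np.random.permutation(n_samples)
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]
```

With the full CIFAR-10 arrays (60000 samples), this returns `labels_train` with shape `(54000,)` and `labels_test` with shape `(6000,)`. Flattening with `reshape(-1)` is one of several possible fixes. Changing `read_cifar` itself to return 1-D labels would work just as well.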