Wrong labels_train shape when splitting
When I run
from read_cifar import read_cifar, split_dataset
data, labels = read_cifar("data/cifar-10-batches-py")
data_train, labels_train, data_test, labels_test = split_dataset(
    data, labels, split=0.9
)
assert data_train.shape == (54000, 3072)
assert labels_train.shape == (54000,)
I get
Traceback (most recent call last):
  File "/home/qgallouedec/image-classification/test.py", line 8, in <module>
    assert labels_train.shape == (54000,)
AssertionError
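Note that the data_train assertion passes, so only the labels come back with an unexpected shape. The report does not show what shape labels_train actually has, so the following is only a guess: one possible cause is that labels is stored as a 2-D array such as (N, 1) and the split keeps that extra axis. Below is a minimal sketch of a split that flattens the labels before indexing. The split_dataset here is a hypothetical stand-in written for illustration, not the repository's implementation, and the seed parameter and the small synthetic arrays are assumptions as well.

```python
import numpy as np


def split_dataset(data, labels, split, seed=0):
    """Shuffle and split data/labels (hypothetical stand-in, not the repo's code)."""
    # Assumption: labels may arrive as (N, 1) or (1, N); flatten to (N,).
    labels = np.asarray(labels).reshape(-1)
    n = data.shape[0]
    if labels.shape[0] != n:
        raise ValueError("data and labels must have the same number of samples")

    # Shuffle sample indices once and reuse them for both arrays.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(round(n * split))
    train_idx, test_idx = perm[:n_train], perm[n_train:]

    return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]


# Small synthetic stand-in for CIFAR-10: 60 samples instead of 60000.
data = np.zeros((60, 3072), dtype=np.float32)
labels = (np.arange(60) % 10).reshape(-1, 1)  # deliberately 2-D

data_train, labels_train, data_test, labels_test = split_dataset(
    data, labels, split=0.9
)
print(data_train.shape, labels_train.shape)
```

Printing labels.shape and labels_train.shape in the failing script would confirm or rule out this guess.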