import numpy as np
import read_cifar as rc
from read_cifar import read_cifar_batch
from read_cifar import read_cifar
def test_read_cifar_batch():
    """Check that read_cifar_batch loads a single batch with the expected layout.

    Reads the batch file ``data/data_batch_1`` (relative to the working
    directory) and asserts on the shape and dtype of the returned arrays.
    """
    # Use a forward slash: Python accepts it as a path separator on Windows as
    # well as POSIX, whereas a backslash is an invalid "\d" string escape and
    # is not a separator outside Windows.
    batch_path = "data/data_batch_1"
    data, labels = read_cifar_batch(batch_path)

    # Check that data has the right shape and type
    assert data.shape == (10000, 3072)
    assert data.dtype == np.float32

    # Check that labels has the right shape and type
    assert labels.shape == (10000,)
    assert labels.dtype == np.int64
    print("All tests passed successfully.")
   


def test_read_cifar():
    """Check that read_cifar loads the full dataset from the ``data`` directory.

    Asserts on the shape and dtype of both returned arrays.
    """
    n_samples, n_features = 60000, 3072
    images, targets = read_cifar('data')

    # Feature matrix: one row per sample, stored as float32.
    assert images.shape == (n_samples, n_features)
    assert images.dtype == np.float32

    # Label vector: one entry per sample, stored as int64.
    assert targets.shape == (n_samples,)
    assert targets.dtype == np.int64
    print("All tests passed successfully.")

def test_split_dataset():
    """Check that split_dataset partitions data and labels consistently.

    A seeded generator keeps the test reproducible. Column 0 of every row is
    overwritten with the row's index, and each row's label is that same index.
    This lets the test verify that every (row, label) pair survives the split
    and that no sample is dropped or duplicated.
    """
    n_samples, n_features = 150, 4
    split = 0.8
    rng = np.random.default_rng(0)
    data = rng.standard_normal((n_samples, n_features))
    ids = np.arange(n_samples)
    data[:, 0] = ids  # tag each row with its index
    labels = ids.copy()  # integer labels, one per row, matching the tag

    data_train, labels_train, data_test, labels_test = rc.split_dataset(data, labels, split)

    # Every sample lands in one of the two subsets, with matching label counts.
    assert data_train.shape[0] + data_test.shape[0] == n_samples
    assert len(labels_train) == len(data_train)
    assert len(labels_test) == len(data_test)

    # The feature dimension is left untouched.
    assert data_train.shape[1] == n_features
    assert data_test.shape[1] == n_features

    # Train size follows the requested ratio; tolerate any rounding choice.
    assert abs(data_train.shape[0] - split * n_samples) <= 1

    # Each row still carries its own label after the split...
    np.testing.assert_array_equal(data_train[:, 0], labels_train)
    np.testing.assert_array_equal(data_test[:, 0], labels_test)

    # ...and the two subsets together contain every sample exactly once.
    all_ids = np.concatenate([labels_train, labels_test])
    np.testing.assert_array_equal(np.sort(all_ids), ids)

    print("All tests passed successfully.")
