from mlp import one_hot
import numpy as np

def test_one_hot():
    """Check that ``one_hot`` turns integer labels into a valid one-hot matrix.

    For labels 0..9 the result must have one row per label and 10 columns.
    Every entry must be 0 or 1, each row must contain exactly one 1, and that
    1 must sit in the column given by the row's label.
    """
    x = list(range(10))
    oh = one_hot(x)

    # NOTE(review): the expected column count of 10 is assumed to be one_hot's
    # default number of classes (it is called without a class-count argument).
    assert oh.shape == (len(x), 10)

    # Every entry is either 0 or 1 (no fractional "soft" encodings).
    assert np.isin(oh, (0, 1)).all()

    # Exactly one hot entry per row. Compare against a vector of length
    # len(x) rather than relying on broadcasting against a (1, 10) array.
    np.testing.assert_array_equal(np.sum(oh, axis=1), np.ones(len(x)))

    # The hot entry is in the column named by the label.
    np.testing.assert_array_equal(np.argmax(oh, axis=1), np.asarray(x))