Skip to content
Snippets Groups Projects
Unverified Commit 91871cd9 authored by Jangberry (Nomad-Debian)'s avatar Jangberry (Nomad-Debian)
Browse files

Some more tests

parent 79448985
Branches
No related tags found
No related merge requests found
......@@ -168,11 +168,11 @@ def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch
for i in range(num_epochs):
w1, b1, w2, b2, loss = learn_once_cross_entropy(
w1, b1, w2, b2, data_train, labels_train, learning_rate)
acc.append(test_mlp(w1, b1, w2, b2, data_train, labels_train))
acc.append(mlp_test(w1, b1, w2, b2, data_train, labels_train))
return w1, b1, w2, b2, acc
def test_mlp(w1, b1, w2, b2, data_test, labels_test):
def mlp_test(w1, b1, w2, b2, data_test, labels_test):
"""
Tests the MLP
......
......@@ -27,9 +27,39 @@ def test_learn_once_mse():
assert loss2 < loss
def test_one_hot():
    """one_hot must map an index vector to the matching one-hot matrix."""
    labels = np.array([2, 0, 1])
    # Row i is all zeros except a 1 in column labels[i].
    expected = np.array([
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
    ])
    np.testing.assert_array_equal(one_hot(labels), expected)
def test_learn_once_cross_entropy():
    """Check that one step of cross-entropy learning reduces the loss.

    Runs two consecutive learning steps on the same random batch and
    asserts that the loss after the second step is lower than after the
    first.
    """
    # Fixed seed: without it the loss-decrease assertion is flaky, since
    # a random init/batch occasionally overshoots with lr = 0.1.
    np.random.seed(0)
    N = 30      # number of input samples
    d_in = 3    # input dimension
    d_h = 3     # number of neurons in the hidden layer
    d_out = 5   # output dimension (number of classes)
    w1 = 2 * np.random.rand(d_in, d_h) - 1   # first layer weights in [-1, 1)
    b1 = np.zeros((1, d_h))                  # first layer biases
    w2 = 2 * np.random.rand(d_h, d_out) - 1  # second layer weights in [-1, 1)
    b2 = np.zeros((1, d_out))                # second layer biases
    data = np.random.rand(N, d_in)           # random input batch
    # Bug fix: labels must cover the full range [0, d_out). The original
    # randint(1, d_out, N) could never produce class 0, so the first
    # output class was never exercised by this test.
    targets = np.random.randint(0, d_out, N)
    w1, b1, w2, b2, loss = learn_once_cross_entropy(
        w1, b1, w2, b2, data, targets, 0.1)
    w1, b1, w2, b2, loss2 = learn_once_cross_entropy(
        w1, b1, w2, b2, data, targets, 0.1)
    assert loss2 < loss
def test_run_mlp_training():
    """Smoke test: one epoch of MLP training on the CIFAR-10 dataset.

    Only checks that the full pipeline (load, split, train) runs without
    raising; requires the dataset under data/cifar-10-batches-py/.
    """
    # Load the dataset and hold out 20% of it for testing.
    images, targets = read_cifar.read_cifar("data/cifar-10-batches-py/")
    train_x, train_y, test_x, test_y = read_cifar.split_dataset(
        images, targets, 0.8)
    # Single epoch with learning rate 0.1.
    w1, b1, w2, b2, acc = run_mlp_training(
        train_x, train_y, test_x, test_y, 0.1, 1)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment