diff --git a/tests/learn_once_mse_test.py b/tests/learn_once_mse_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dc57e4c8d120a3a76468af3f9481ce077f0b9d08 --- /dev/null +++ b/tests/learn_once_mse_test.py @@ -0,0 +1,21 @@ +from mlp import learn_once_mse +import numpy as np + +def learn_once_mse_test(): + + + N = 100 + d_in = 30 + d_out = 1 + d_h = 5 + + train = np.random.rand(N, d_in) + targets = np.random.randint(10, size=(N, d_out)) + + w1 = 2 * np.random.rand(d_in, d_h) - 1 + b1 = np.zeros((1, d_h)) + w2 = 2 * np.random.rand(d_h, d_out) - 1 + b2 = np.zeros((1, d_out)) + + w1, b1, w2, b2, loss = learn_once_mse(w1, b1, w2, b2, train, targets, learning_rate=0.01) +