From 98dc5f386ca0452737b155a0fc4cc5d5b796e7de Mon Sep 17 00:00:00 2001 From: lucile <lucile.audard@ecl20.ec-lyon.fr> Date: Thu, 9 Nov 2023 11:29:02 +0100 Subject: [PATCH] Update mlp.py --- mlp.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mlp.py b/mlp.py index 1b7402f..8518643 100644 --- a/mlp.py +++ b/mlp.py @@ -58,7 +58,22 @@ def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate): targets_one_hot = one_hot(labels_train) # target as a one-hot encoding for the desired labels # cross-entropy loss - loss = + loss = -np.sum(targets_one_hot * np.log(predictions)) / N + + # Backpropagation + d_z2 = a2 - targets_one_hot + d_w2 = np.dot(a1.T, d_z2) / N + d_b2 = d_z2 / N + d_a1 = np.dot(d_z2, w2.T) + d_z1 = d_a1 * z1 * (1 - a1) + d_w1 = np.dot(a0.T, d_z1) / N + d_b1 = d_z1 / N + + # Calculation of the updated weights and biases of the network with gradient descent method + w1 -= learning_rate * d_w1 + w2 -= learning_rate * d_w2 + b2 -= learning_rate * d_b2 + b1 -= learning_rate * d_b1 return w1, b1, w2, b2, loss -- GitLab