From 98dc5f386ca0452737b155a0fc4cc5d5b796e7de Mon Sep 17 00:00:00 2001
From: lucile <lucile.audard@ecl20.ec-lyon.fr>
Date: Thu, 9 Nov 2023 11:29:02 +0100
Subject: [PATCH] Update mlp.py

---
 mlp.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/mlp.py b/mlp.py
index 1b7402f..8518643 100644
--- a/mlp.py
+++ b/mlp.py
@@ -58,7 +58,22 @@ def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
     targets_one_hot = one_hot(labels_train) # target as a one-hot encoding for the desired labels
     
     # cross-entropy loss
-    loss = 
+    loss = -np.sum(targets_one_hot * np.log(predictions)) / N
+    
+    # Backpropagation
+    d_z2 = a2 - targets_one_hot
+    d_w2 = np.dot(a1.T, d_z2) / N
+    d_b2 = np.sum(d_z2, axis=0, keepdims=True) / N
+    d_a1 = np.dot(d_z2, w2.T)
+    d_z1 = d_a1 * a1 * (1 - a1)
+    d_w1 = np.dot(a0.T, d_z1) / N
+    d_b1 = np.sum(d_z1, axis=0, keepdims=True) / N
+    
+    # Calculation of the updated weights and biases of the network with gradient descent method
+    w1 -= learning_rate * d_w1
+    w2 -= learning_rate * d_w2
+    b2 -= learning_rate * d_b2
+    b1 -= learning_rate * d_b1
     
     return w1, b1, w2, b2, loss
 
-- 
GitLab