diff --git a/mlp.py b/mlp.py
index 9e45e9c5650ebe1de51fce3a8450d38f92fca17d..38188f18290a7aa432e57167f5a65d09ba41f970 100644
--- a/mlp.py
+++ b/mlp.py
@@ -98,16 +98,12 @@ def learn_once_cross_entropy(
     """
 
     # Forward pass
-    a0 = data  # the data are the input of the first layer
-    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
-    a1 = 1 / (
-        1 + np.exp(-z1)
-    )  # output of the hidden layer (sigmoid activation function)
-    z2 = np.matmul(a1, w2) + b2  # input of the output layer
-    a2 = 1 / (
-        1 + np.exp(-z2)
-    )  # output of the output layer (sigmoid activation function)
-    predictions = a2  # the predicted values are the outputs of the output layer
+    a0 = data  # the data are the input of the first layer
+    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
+    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation)
+    z2 = np.matmul(a1, w2) + b2  # input of the output layer
+    a2 = np.exp(z2) / np.sum(np.exp(z2), axis=1, keepdims=True)  # softmax output
+    predictions = a2  # the predicted values are the outputs of the output layer
 
     one_hot_targets = one_hot(labels_train)
 
@@ -171,7 +167,7 @@ def train_mlp(
 
     for _ in range(num_epoch):
         # Train once
-        w1, b1, w2, b2, _ = learn_once_mse(
+        w1, b1, w2, b2, _ = learn_once_cross_entropy(
             w1, b1, w2, b2, data_train, labels_train, learning_rate
         )
 
@@ -204,16 +200,12 @@ def test_mlp(
         float: The testing accuracy of the model on the given data.
     """
     # Forward pass
-    a0 = data_test  # the data are the input of the first layer
-    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
-    a1 = 1 / (
-        1 + np.exp(-z1)
-    )  # output of the hidden layer (sigmoid activation function)
-    z2 = np.matmul(a1, w2) + b2  # input of the output layer
-    a2 = 1 / (
-        1 + np.exp(-z2)
-    )  # output of the output layer (sigmoid activation function)
-    predictions = a2  # the predicted values are the outputs of the output layer
+    a0 = data_test  # the data are the input of the first layer
+    z1 = np.matmul(a0, w1) + b1  # input of the hidden layer
+    a1 = 1 / (1 + np.exp(-z1))  # output of the hidden layer (sigmoid activation)
+    z2 = np.matmul(a1, w2) + b2  # input of the output layer
+    a2 = np.exp(z2) / np.sum(np.exp(z2), axis=1, keepdims=True)  # softmax output
+    predictions = a2  # the predicted values are the outputs of the output layer
 
     # Compute accuracy
     accuracy = np.mean(np.argmax(predictions, axis=1) == labels_test)
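
Note: the softmax added in both forward passes exponentiates the raw logits directly, so np.exp can overflow for large values of z2. A minimal numerically stable sketch, assuming a helper named softmax (illustrative only, not part of mlp.py):

import numpy as np

def softmax(z):
    # Shift each row by its maximum before exponentiating; the result is
    # unchanged, but the exponent stays bounded, which avoids overflow.
    z_shifted = z - np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z, axis=1, keepdims=True)

Usage would be a2 = softmax(z2) in place of the explicit expression in both learn_once_cross_entropy and test_mlp. With a softmax output and a cross-entropy loss, the output-layer gradient in the backward pass reduces to a2 - one_hot_targets, assuming the gradient code below the shown hunk uses that simplification.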