Skip to content
Snippets Groups Projects
Commit a60e1f05 authored by Audard Lucile's avatar Audard Lucile
Browse files

Update mlp.py

parent 98dc5f38
Branches
No related tags found
No related merge requests found
...@@ -78,5 +78,49 @@ def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate): ...@@ -78,5 +78,49 @@ def learn_once_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
return w1, b1, w2, b2, loss return w1, b1, w2, b2, loss
def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    """Train the two-layer MLP for `num_epoch` epochs.

    Each epoch performs one cross-entropy learning step on the full training
    set, then measures training accuracy with a forward pass through the
    freshly updated weights.

    Returns the updated (w1, b1, w2, b2) and the list of per-epoch accuracies.
    """
    train_accuracies = []
    for _ in range(num_epoch):
        # One gradient-descent step over the whole training set.
        w1, b1, w2, b2, loss = learn_once_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)
        # Forward pass with the updated parameters (sigmoid on both layers).
        hidden = 1 / (1 + np.exp(-(np.matmul(data_train, w1) + b1)))
        output = 1 / (1 + np.exp(-(np.matmul(hidden, w2) + b2)))
        # Predicted class is the index of the largest output unit.
        predicted_classes = np.argmax(output, axis=1)
        # Accuracy = fraction of training samples predicted correctly.
        train_accuracies.append(np.mean(labels_train == predicted_classes))
    return w1, b1, w2, b2, train_accuracies
def test_mlp(w1, b1, w2, b2, data_test, labels_test):
return test_accuracy
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment