mlp.py

import numpy as np
from read_cifar import *
import matplotlib.pyplot as plt

d_h = 64           # number of hidden units
epsilon = 0.00001  # small constant guarding against log(0) in the cross-entropy loss

def sigma(z):
    """Elementwise sigmoid activation."""
    return 1 / (1 + np.exp(-z))
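
# The backward passes below use the sigmoid derivative
# sigma'(z) = sigma(z) * (1 - sigma(z)), which appears inlined as a * (1 - a).
# A hypothetical helper that makes the identity explicit (a sketch,
# not called by the training code):
def sigma_prime(z):
    s = sigma(z)
    return s * (1 - s)
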
def learn_one_mse(w1, b1, w2, b2, data, targets, learning_rate):
    """One full-batch gradient step on the MSE loss; returns updated weights and the loss."""
    # Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = sigma(z1)
    z2 = np.matmul(a1, w2) + b2
    a2 = sigma(z2)
    predictions = a2
    N = np.shape(data)[0]
    d_out = np.shape(targets)[1]
    # Backward pass; weight gradients are averaged over the batch,
    # consistent with the bias gradients
    dC_da2 = 2 * (a2 - targets) / d_out
    dC_dz2 = dC_da2 * a2 * (1 - a2)
    dC_dw2 = np.matmul(a1.T, dC_dz2) / N
    dC_db2 = dC_dz2.mean(axis=0)
    dC_da1 = np.matmul(dC_dz2, w2.T)  # chain rule through z2 = a1 @ w2 + b2
    dC_dz1 = a1 * (1 - a1) * dC_da1
    dC_dw1 = np.matmul(a0.T, dC_dz1) / N
    dC_db1 = dC_dz1.mean(axis=0)
    # Gradient descent update (in place)
    w1 -= learning_rate * dC_dw1
    b1 -= learning_rate * dC_db1
    w2 -= learning_rate * dC_dw2
    b2 -= learning_rate * dC_db2
    loss = np.mean(np.square(predictions - targets))
    return (w1, b1, w2, b2, loss)
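
# A minimal finite-difference sanity check for the MSE backward pass, using
# small arbitrary shapes; this is a sketch for manual verification and is
# never called by the training code.
def _check_mse_gradient(h=1e-6):
    rng = np.random.default_rng(0)
    data = rng.random((5, 4))
    targets = rng.random((5, 2))
    w1 = rng.standard_normal((4, 3))
    b1 = np.zeros((1, 3))
    w2 = rng.standard_normal((3, 2))
    b2 = np.zeros((1, 2))

    def loss_at(w1_):
        a1 = sigma(np.matmul(data, w1_) + b1)
        a2 = sigma(np.matmul(a1, w2) + b2)
        return np.mean(np.square(a2 - targets))

    # Numerical gradient of the loss w.r.t. one entry of w1 (central difference)
    w1_plus, w1_minus = w1.copy(), w1.copy()
    w1_plus[0, 0] += h
    w1_minus[0, 0] -= h
    numeric = (loss_at(w1_plus) - loss_at(w1_minus)) / (2 * h)
    # Analytic gradient, recovered from one update with learning_rate = 1:
    # the step is w1_new = w1 - grad, so grad = w1 - w1_new
    new_w1, _, _, _, _ = learn_one_mse(w1.copy(), b1.copy(), w2.copy(), b2.copy(), data, targets, 1.0)
    analytic = (w1 - new_w1)[0, 0]
    return numeric, analytic  # the two values should agree closely
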
def one_hot(labels, num_classes=None):
    """Encode integer labels as a (len(labels), num_classes) one-hot matrix."""
    if num_classes is None:
        num_classes = np.max(labels) + 1
    one_hot_matrix = np.zeros((len(labels), num_classes), dtype=int)
    for i in range(len(labels)):
        one_hot_matrix[i, labels[i]] = 1
    return one_hot_matrix
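
# Illustrative check with arbitrary labels:
# one_hot(np.array([0, 2, 1]), num_classes=3) gives
#   [[1, 0, 0],
#    [0, 0, 1],
#    [0, 1, 0]]
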
def learn_one_cross_entropy(w1, b1, w2, b2, data, labels_train, learning_rate):
    """One full-batch gradient step on the binary cross-entropy loss
    (sigmoid outputs, one-hot targets); returns updated weights and the loss."""
    # Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = sigma(z1)
    z2 = np.matmul(a1, w2) + b2
    a2 = sigma(z2)
    predictions = a2
    N = len(labels_train)
    y = one_hot(labels_train)
    # Backward pass: for sigmoid + cross-entropy the output-layer gradient
    # simplifies to a2 - y (see the derivation note below)
    dC_dz2 = a2 - y
    dC_dw2 = np.matmul(a1.T, dC_dz2) / N
    dC_db2 = dC_dz2.mean(axis=0)
    dC_da1 = np.matmul(dC_dz2, w2.T)
    dC_dz1 = a1 * (1 - a1) * dC_da1
    dC_dw1 = np.matmul(a0.T, dC_dz1) / N
    dC_db1 = dC_dz1.mean(axis=0)
    # Gradient descent update (in place)
    w1 -= learning_rate * dC_dw1
    b1 -= learning_rate * dC_db1
    w2 -= learning_rate * dC_dw2
    b2 -= learning_rate * dC_db2
    # Loss in bits (log base 2); epsilon guards against log(0)
    loss = -np.sum(y * np.log2(predictions + epsilon) + (1 - y) * np.log2(1 - predictions + epsilon)) / N
    return (w1, b1, w2, b2, loss)
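
# Derivation note for dC_dz2 = a2 - y: with a = sigma(z) and the per-sample
# binary cross-entropy C = -(y*log(a) + (1-y)*log(1-a)) (natural log),
# dC/da = (a - y) / (a * (1 - a)) and da/dz = a * (1 - a),
# so the chain rule collapses to dC/dz = a - y.
# (The loss above is reported in base 2, which only rescales it by 1/ln 2.)
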
def predict_class(predictions):
    """Return the index of the highest score in each row."""
    return np.argmax(predictions, axis=1)

def accuracy(y_true, y_pred):
    """Fraction of matching entries between two label vectors."""
    return np.mean(y_true == y_pred)
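
# Illustrative example with arbitrary scores:
# predict_class(np.array([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])) -> array([1, 0])
# accuracy(np.array([1, 0]), np.array([1, 0])) -> 1.0
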
def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    """Run num_epoch full-batch gradient steps and record the training
    accuracy after each step."""
    training_accuracies = []
    for k in range(num_epoch):
        w1, b1, w2, b2, loss = learn_one_cross_entropy(w1, b1, w2, b2, data_train, labels_train, learning_rate)
        # Forward pass with the updated weights to measure training accuracy
        a1 = sigma(np.matmul(data_train, w1) + b1)
        a2 = sigma(np.matmul(a1, w2) + b2)
        training_accuracies.append(accuracy(labels_train, predict_class(a2)))
    return (w1, b1, w2, b2, training_accuracies)
def test_mlp(w1, b1, w2, b2, data_test, labels_test):
    """Forward pass on the test set; returns the classification accuracy."""
    a0 = data_test
    z1 = np.matmul(a0, w1) + b1
    a1 = sigma(z1)
    z2 = np.matmul(a1, w2) + b2
    a2 = sigma(z2)
    predictions = predict_class(a2)
    test_accuracy = accuracy(labels_test, predictions)
    return test_accuracy
def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h: int, learning_rate, num_epoch):
    """Initialize a one-hidden-layer MLP, train it, and evaluate it on the test set."""
    d_in = np.shape(data_train)[1]
    d_out = np.max(labels_train) + 1
    # Uniform initialization in [-1, 1); the hidden width is the d_h argument
    w1 = 2 * np.random.rand(d_in, d_h) - 1
    b1 = np.zeros((1, d_h))
    w2 = 2 * np.random.rand(d_h, d_out) - 1
    b2 = np.zeros((1, d_out))
    w1, b1, w2, b2, training_accuracies = train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    final_testing_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)
    return training_accuracies, final_testing_accuracy
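
# Typical usage (hypothetical shapes, assuming read_cifar returns flattened
# images, e.g. (N, 3072) floats, and integer labels in [0, 9]):
#   accs, test_acc = run_mlp_training(data_train, labels_train,
#                                     data_test, labels_test,
#                                     d_h=64, learning_rate=0.1, num_epoch=100)
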
if __name__ == "__main__":
    learning_rate = 0.1
    num_epoch = 100
    split_factor = 0.9
    batch_dir = 'data/cifar-10-batches-py/'
    data, labels = read_cifar(batch_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split_factor)
    training_accuracies, final_testing_accuracy = run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch)
    epochs = list(range(1, num_epoch + 1))
    plt.plot(epochs, training_accuracies)
    plt.xlabel("epoch")
    plt.ylabel("training accuracy")
    plt.show()