import numpy as np
from utils.forward_pass import forward_pass
from utils.binary_cross_entropy import binary_cross_entropy

def adjust_weights_binary_cross_entropy(a1, a2, w1, b1, w2, b2, data, targets, learning_rate):
    """Backpropagate the binary cross-entropy loss through a one-hidden-layer
    sigmoid network and apply one gradient-descent step to the parameters.

    Assumes a1 = sigmoid(data @ w1 + b1) and a2 = sigmoid(a1 @ w2 + b2), so the
    output-layer error simplifies to dC/dz2 = a2 - targets.
    """
    batch_size = data.shape[0]
    # Output layer: sigmoid + binary cross-entropy gives dC/dz2 = a2 - targets
    dCdZ2 = a2 - targets
    dCdW2 = np.matmul(a1.T, dCdZ2) / batch_size
    dCdB2 = np.sum(dCdZ2, axis=0, keepdims=True) / batch_size
    # Hidden layer: backpropagate through w2, then through the sigmoid derivative a1 * (1 - a1)
    dCdA1 = np.matmul(dCdZ2, w2.T)
    dCdZ1 = dCdA1 * a1 * (1 - a1)
    dCdW1 = np.matmul(data.T, dCdZ1) / batch_size
    dCdB1 = np.sum(dCdZ1, axis=0, keepdims=True) / batch_size

    # Gradient-descent step (the in-place -= updates also mutate the caller's arrays)
    w1 -= learning_rate * dCdW1
    b1 -= learning_rate * dCdB1
    w2 -= learning_rate * dCdW2
    b2 -= learning_rate * dCdB2
    return w1, b1, w2, b2

def learn_once_cross_entropy(w1, b1, w2, b2, data, targets, learning_rate):
    """Run one training step (forward pass, BCE loss, backward pass) and return
    the updated parameters together with the loss value.
    """
    a1, a2 = forward_pass(w1, b1, w2, b2, data)
    loss = binary_cross_entropy(a2, targets)
    w1, b1, w2, b2 = adjust_weights_binary_cross_entropy(a1, a2, w1, b1, w2, b2, data, targets, learning_rate)
    return w1, b1, w2, b2, loss
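

# Minimal usage sketch (not part of the original module): the shapes, hyperparameters,
# and synthetic data below are illustrative assumptions; forward_pass and
# binary_cross_entropy are the utils functions imported above.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_samples, n_in, n_hidden, n_out = 32, 8, 16, 1

    # Synthetic batch: data is (n_samples, n_in), targets are 0/1 labels of shape (n_samples, n_out)
    data = rng.normal(size=(n_samples, n_in))
    targets = rng.integers(0, 2, size=(n_samples, n_out)).astype(float)

    # Small random weights and zero biases, shaped to match the matmuls above
    w1 = rng.normal(scale=0.1, size=(n_in, n_hidden))
    b1 = np.zeros((1, n_hidden))
    w2 = rng.normal(scale=0.1, size=(n_hidden, n_out))
    b2 = np.zeros((1, n_out))

    for epoch in range(10):
        w1, b1, w2, b2, loss = learn_once_cross_entropy(
            w1, b1, w2, b2, data, targets, learning_rate=0.1
        )
        print(f"epoch {epoch}: loss = {loss:.4f}")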