import numpy as np
from utils.forward_pass import forward_pass

def adjust_weights_mse(a1, a2, w1, b1, w2, b2, data, targets, learning_rate):
    """Apply one gradient-descent step to a two-layer network under MSE loss.

    The gradients are those of ``np.mean(np.square(a2 - targets))``, i.e. the
    mean over every batch row and every output unit. The ``a * (1 - a)`` terms
    are the sigmoid derivative, so the forward pass that produced ``a1`` and
    ``a2`` is assumed to apply a sigmoid on both layers.

    Parameters
    ----------
    a1 : np.ndarray
        Hidden activations, shape (batch, hidden).
    a2 : np.ndarray
        Output activations, shape (batch, out).
    w1 : np.ndarray
        First-layer weights, shape (in, hidden).
    b1 : np.ndarray
        First-layer bias, shape (1, hidden) or (hidden,).
    w2 : np.ndarray
        Second-layer weights, shape (hidden, out).
    b2 : np.ndarray
        Second-layer bias, shape (1, out) or (out,).
    data : np.ndarray
        Network inputs, shape (batch, in).
    targets : np.ndarray
        Desired outputs, shape (batch, out).
    learning_rate : float
        Gradient-descent step size.

    Returns
    -------
    tuple of np.ndarray
        ``(w1, b1, w2, b2)``: the very same array objects that were passed
        in, updated in place.
    """
    batch_size = data.shape[0]
    n_out = targets.shape[1]

    # Output layer. d(mean over batch*out)/d(a2) = 2 * (a2 - t) / (batch * out).
    # The 1/n_out factor is applied here; the 1/batch_size factor is applied
    # to each parameter gradient below.
    dCdA2 = 2 * (a2 - targets) / n_out
    dCdZ2 = dCdA2 * a2 * (1 - a2)
    dCdW2 = np.matmul(a1.T, dCdZ2) / batch_size
    dCdB2 = np.sum(dCdZ2, axis=0) / batch_size

    # Hidden layer. This must use w2 *before* w2 is updated below.
    dCdA1 = np.matmul(dCdZ2, w2.T)
    dCdZ1 = dCdA1 * a1 * (1 - a1)
    dCdW1 = np.matmul(data.T, dCdZ1) / batch_size
    dCdB1 = np.sum(dCdZ1, axis=0) / batch_size

    # In-place updates: callers holding references to these arrays observe the
    # new values. Bias gradients are reshaped to the bias's own shape, so both
    # (1, n) and (n,) biases work (an (n,) bias cannot absorb a (1, n) update
    # in place).
    w2 -= learning_rate * dCdW2
    w1 -= learning_rate * dCdW1
    b1 -= learning_rate * dCdB1.reshape(b1.shape)
    b2 -= learning_rate * dCdB2.reshape(b2.shape)
    return w1, b1, w2, b2


def learn_once_mse(w1: np.ndarray, b1: np.ndarray, w2: np.ndarray, b2: np.ndarray, data: np.ndarray, targets: np.ndarray, learning_rate: float):
    """Perform a single MSE training step on a two-layer network.

    Runs ``forward_pass`` on ``data``, measures the mean squared error against
    ``targets`` using the pre-update parameters, then delegates the gradient
    step to ``adjust_weights_mse``. That call updates the weight and bias
    arrays in place.

    Returns
    -------
    tuple
        ``(w1, b1, w2, b2, loss)``: the updated parameters followed by the
        loss measured *before* the update.
    """
    hidden_act, output_act = forward_pass(w1, b1, w2, b2, data)

    # Measure the error before the parameters are modified.
    residual = output_act - targets
    mse = np.mean(np.square(residual))

    new_w1, new_b1, new_w2, new_b2 = adjust_weights_mse(
        hidden_act, output_act, w1, b1, w2, b2, data, targets, learning_rate
    )
    return new_w1, new_b1, new_w2, new_b2, mse