import numpy as np
from read_cifar import read_cifar, split_dataset
import matplotlib.pyplot as plt
import os

# Question 10: Implémentation avec MSE
def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate):
    
    # On code la Forward pass: propagation des données à travers le réseau
    a0 = data  # Couche d'entrée
    z1 = np.matmul(a0, w1) + b1  # Première couche cachée (pré-activation)
    a1 = 1 / (1 + np.exp(-z1))   # Activation sigmoid
    z2 = np.matmul(a1, w2) + b2  # Couche de sortie (pré-activation)
    a2 = 1 / (1 + np.exp(-z2))   # Activation sigmoid finale
    predictions = a2

    # On calcul l'erreur quadratique moyenne
    loss = np.mean(np.square(predictions - targets))

    # On code ensuite la Backward pass: calcul des gradients et mise à jour des poids
    dC_da2 = 2 * (predictions - targets) / targets.shape[0]  # Dérivée par rapport à la sortie
    dC_dz2 = dC_da2 * (a2 * (1 - a2))  # Dérivée de la sigmoid
    dC_dw2 = np.matmul(a1.T, dC_dz2)   # Gradient pour w2
    dC_db2 = np.mean(dC_dz2, axis=0)   # Gradient pour b2

    # On propage l'erreur vers la première couche
    dC_da1 = np.matmul(dC_dz2, w2.T)
    dC_dz1 = dC_da1 * (a1 * (1 - a1))  # Dérivée de la sigmoid
    dC_dw1 = np.matmul(a0.T, dC_dz1)   # Gradient pour w1
    dC_db1 = np.mean(dC_dz1, axis=0)   # Gradient pour b1

    # On met à jour les poids et biais avec le gradient descent
    w1 = w1 - learning_rate * dC_dw1
    b1 = b1 - learning_rate * dC_db1
    w2 = w2 - learning_rate * dC_dw2
    b2 = b2 - learning_rate * dC_db2

    return w1, b1, w2, b2, loss

# Question 11: One-hot encoding
def one_hot(labels):
    n_classes = 10  # CIFAR-10 a 10 classes
    return np.eye(n_classes)[labels]

# Question 12: Implémentation avec la Cross-Entropy
def learn_once_cross_entropy(w1, b1, w2, b2, data, labels, learning_rate):
    batch_size = data.shape[0]
    
    # On implemente la Forward pass
    a0 = data
    z1 = np.matmul(a0, w1) + b1
    a1 = 1 / (1 + np.exp(-z1))
    z2 = np.matmul(a1, w2) + b2
    
    # On implemente notre sofmax
    z2 = z2 - np.max(z2, axis=1, keepdims=True)
    exp_z2 = np.exp(z2)
    a2 = exp_z2 / np.sum(exp_z2, axis=1, keepdims=True)
    
    # One-hot encoding
    y = one_hot(labels)
    
    # Cross entropy loss avec réduction correcte
    loss = -np.sum(y * np.log(a2 + 1e-15)) / batch_size
    
    # Backward pass avec normalisation correcte
    dC_dz2 = (a2 - y) / batch_size
    dC_dw2 = np.matmul(a1.T, dC_dz2)
    dC_db2 = np.mean(dC_dz2, axis=0)
    dC_da1 = np.matmul(dC_dz2, w2.T)
    dC_dz1 = dC_da1 * (a1 * (1 - a1))
    dC_dw1 = np.matmul(a0.T, dC_dz1)
    dC_db1 = np.mean(dC_dz1, axis=0)
    
    # Update weights and biases
    w1 = w1 - learning_rate * dC_dw1
    b1 = b1 - learning_rate * dC_db1
    w2 = w2 - learning_rate * dC_dw2
    b2 = b2 - learning_rate * dC_db2
    
    return w1, b1, w2, b2, loss
# Question 13: Training function
def train_mlp(w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch):
    train_accuracies = []
    batch_size = 128  # Taille de batch raisonnable
    
    for epoch in range(num_epoch):
        # ON mélange aléatoirement les données à chaque époque
        indices = np.random.permutation(len(data_train))
        data_train = data_train[indices]
        labels_train = labels_train[indices]
        
        epoch_losses = []
        
        # On réalise l'entrainement par batch
        for i in range(0, len(data_train), batch_size):
            # On récupère notre batch de données + les labels
            batch_data = data_train[i:i+batch_size]
            batch_labels = labels_train[i:i+batch_size]
            
            # On réalise l'entrainement avec la fonction précédement définit sur le batch
            w1, b1, w2, b2, loss = learn_once_cross_entropy(
                w1, b1, w2, b2, batch_data, batch_labels, learning_rate)
            epoch_losses.append(loss)
        
        # On calcul finalement la précision de notre entrainement
        predictions = predict_mlp(w1, b1, w2, b2, data_train) # fonction définit en dessous
        # Pour l'accuracy on peut directement faire la moyenne du tableau contenant True si
        # la prédiction est vrai (si prediction = label) le mean sur [True,False] renvois 0.5
        accuracy = np.mean(predictions == labels_train)
        train_accuracies.append(accuracy)
        
        # On affiche l'évolution de l'entrainement dans la console
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epoch}, Loss: {np.mean(epoch_losses):.4f}, Accuracy: {accuracy:.4f}")
    
    return w1, b1, w2, b2, train_accuracies

def predict_mlp(w1, b1, w2, b2, data):
    # On Forward pass jusqu'à la prédiction (on avait des soucis sans le faire)
    z1 = np.matmul(data, w1) + b1
    a1 = 1 / (1 + np.exp(-z1))
    z2 = np.matmul(a1, w2) + b2
    exp_z2 = np.exp(z2 - np.max(z2, axis=1, keepdims=True))
    a2 = exp_z2 / np.sum(exp_z2, axis=1, keepdims=True)
    return np.argmax(a2, axis=1)

# Question 14: La fonction pour le test (comme a la fin de notre fonction train_mlp)
# On la mise aussi dans la fonction du dessus pour avoir un suivit de la loss
def test_mlp(w1, b1, w2, b2, data_test, labels_test):
    predictions = predict_mlp(w1, b1, w2, b2, data_test)
    return np.mean(predictions == labels_test)

# Question 15: La fonction pour l'entrainement complet
def run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epoch):
    # On initialise tailles, poids ...
    d_in = data_train.shape[1]
    d_out = 10  # CIFAR-10 classes
    w1 = 2 * np.random.rand(d_in, d_h) - 1
    b1 = np.zeros((1, d_h))
    w2 = 2 * np.random.rand(d_h, d_out) - 1
    b2 = np.zeros((1, d_out))
    
    # On passe a l'entrainement 
    w1, b1, w2, b2, train_accuracies = train_mlp(
        w1, b1, w2, b2, data_train, labels_train, learning_rate, num_epoch)
    
    # On passe ensuite au test
    test_accuracy = test_mlp(w1, b1, w2, b2, data_test, labels_test)
    return train_accuracies, test_accuracy

# Question 16: Fonction pour plot notre courbe d'entrainement
def plot_learning_curve(train_accuracies, test_accuracies):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy')
    plt.plot(range(1, len(test_accuracies) + 1), test_accuracies, label='Test Accuracy', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('MLP Learning Curve')
    plt.legend()
    plt.grid(True)
    
    # On sauvegarder le graphique
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/mlp.png')
    plt.close()


if __name__ == "__main__":
    # On charge les données
    cifar_dir = "data/cifar-10-batches-py"
    all_data, all_labels = read_cifar(cifar_dir)
    
    # On normalise les données (pixel codé sur 8 bit)
    all_data = all_data / 255.0
    
    # On split nos datas
    data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9)
    
    # On initialise les paramètres, ils sont fournit dans le sujet
    d_h = 64
    learning_rate = 0.1
    num_epoch = 100
    
    # On démarre le training
    train_accuracies, test_accuracy = run_mlp_training(
        data_train, labels_train, data_test, labels_test,
        d_h, learning_rate, num_epoch
    )
    
    print(f"\nPrécision finale sur le test : {test_accuracy:.4f}")
    plot_learning_curve(train_accuracies, test_accuracy)