import numpy as np
import matplotlib.pyplot as plt
from read_cifar import read_cifar, split_dataset
import os

# Question 1: Compute distance matrix
def distance_matrix(X1: np.ndarray, X2: np.ndarray) -> np.ndarray:
    """
    Compute L2 Euclidean distance matrix between two matrices.
    Using the formula: (a-b)^2 = a^2 + b^2 - 2ab
    
    Args:
        X1: First matrix of shape (n_samples_1, n_features)
        X2: Second matrix of shape (n_samples_2, n_features)
        
    Returns:
        distances: Matrix of shape (n_samples_1, n_samples_2) containing
                  pairwise L2 distances
    """
    # On calcul les norme carrée de nos vecteurs directement avec numpy
    X1_norm = np.sum(X1**2, axis=1)
    X2_norm = np.sum(X2**2, axis=1)
    
    # On reshape pour pouvoir effectuer nos calculs matriciel directement
    X1_norm = X1_norm.reshape(-1, 1) # Vecteur colonne
    X2_norm = X2_norm.reshape(1, -1) # Vecteur ligne
    
    # On calcul la disctance en utilisant direct la formule : (a-b)^2 = a^2 + b^2 - 2ab
    distances = X1_norm + X2_norm - 2 * np.dot(X1, X2.T)
    
    # On obtenait parfois des valeurs négative (surement a cause d'erreur numérique du calcul python)
    distances = np.maximum(distances, 0)
    
    return np.sqrt(distances)

# Question 2: KNN prediction
def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int) -> np.ndarray:
    num_test = dists.shape[0] # donne le nombre d'echantillons de test
    predictions = np.zeros(num_test, dtype=np.int64) # on sotcke ici les predictions qu'on va faire
    
    # On boucle sur les echantillons de test
    for i in range(num_test):
        # On récupere les k plus proches voisins direct avec argsort qui permet de chopper les indices qui permetterais un orgre croissant
        k_nearest_indices = np.argsort(dists[i])[:k]
        k_nearest_labels = labels_train[k_nearest_indices]
        
        # grace a bincount on compte le nombre d'element de chaque classe, et on récupere avec argmax le majoritaire 
        predictions[i] = np.bincount(k_nearest_labels).argmax()
    
    return predictions

# Question 3: Evaluate KNN classifier
def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray, 
                data_test: np.ndarray, labels_test: np.ndarray, k: int) -> float:
    # On commence par calculer les distances que l'on place dans notre matrice dists
    dists = distance_matrix(data_test, data_train)
    
    # On fait ensuite les prédiction avec knn_predict
    predictions = knn_predict(dists, labels_train, k)
    
    # Finalement on calcul notre précision en regardant les elts bien classés
    correct = 0
    total = len(predictions)
    for pred, true in zip(predictions, labels_test):
        if pred == true:
            correct += 1
    accuracy = correct / total

    return accuracy

# Question 4: On plot nos accuracy en fonction de k
def plot_accuracy_vs_k(data_train: np.ndarray, labels_train: np.ndarray,
                      data_test: np.ndarray, labels_test: np.ndarray,
                      k_values: list) -> None:
    accuracies = []
    
    # On boucle sur les valeurs de k choisit et on utilise notre fonction evaluate
    for k in k_values:
        accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        accuracies.append(accuracy)
        print(f"k={k}: accuracy={accuracy:.4f}")
    
    # On creer notre plot et on le sauvegarde
    plt.figure(figsize=(10, 6))
    plt.plot(k_values, accuracies, 'bo-')
    plt.xlabel('k (number of neighbors)')
    plt.ylabel('Accuracy')
    plt.title('KNN Classification Accuracy vs k')
    plt.grid(True)
    
    os.makedirs('results', exist_ok=True)
    
    plt.savefig('results/knn.png')
    plt.close()

if __name__ == "__main__":
    # On charge les données CIFAR
    cifar_dir = "data/cifar-10-batches-py"
    all_data, all_labels = read_cifar(cifar_dir)
    
    # On split en train et test 
    data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9)
    
    # On creer le plot pour étudier l impact de k sur la précision (k allant de 1 a 20)
    k_values = list(range(1, 21))
    plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_values)