diff --git a/knn.py b/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..6a3ac343743e0ed17efc26e047ab43219262ef78 --- /dev/null +++ b/knn.py @@ -0,0 +1,106 @@ +import numpy as np +import matplotlib.pyplot as plt +from read_cifar import read_cifar, split_dataset +import os + +# Question 1: Compute distance matrix +def distance_matrix(X1: np.ndarray, X2: np.ndarray) -> np.ndarray: + """ + Compute L2 Euclidean distance matrix between two matrices. + Using the formula: (a-b)^2 = a^2 + b^2 - 2ab + + Args: + X1: First matrix of shape (n_samples_1, n_features) + X2: Second matrix of shape (n_samples_2, n_features) + + Returns: + distances: Matrix of shape (n_samples_1, n_samples_2) containing + pairwise L2 distances + """ + # On calcul les norme carrée de nos vecteurs directement avec numpy + X1_norm = np.sum(X1**2, axis=1) + X2_norm = np.sum(X2**2, axis=1) + + # On reshape pour pouvoir effectuer nos calculs matriciel directement + X1_norm = X1_norm.reshape(-1, 1) # Vecteur colonne + X2_norm = X2_norm.reshape(1, -1) # Vecteur ligne + + # On calcul la disctance en utilisant direct la formule : (a-b)^2 = a^2 + b^2 - 2ab + distances = X1_norm + X2_norm - 2 * np.dot(X1, X2.T) + + # On obtenait parfois des valeurs négative (surement a cause d'erreur numérique du calcul python) + distances = np.maximum(distances, 0) + + return np.sqrt(distances) + +# Question 2: KNN prediction +def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int) -> np.ndarray: + num_test = dists.shape[0] # donne le nombre d'echantillons de test + predictions = np.zeros(num_test, dtype=np.int64) # on sotcke ici les predictions qu'on va faire + + # On boucle sur les echantillons de test + for i in range(num_test): + # On récupere les k plus proches voisins direct avec argsort qui permet de chopper les indices qui permetterais un orgre croissant + k_nearest_indices = np.argsort(dists[i])[:k] + k_nearest_labels = labels_train[k_nearest_indices] + + # grace a bincount on compte le nombre d'element de chaque classe, et on récupere avec argmax le majoritaire + predictions[i] = np.bincount(k_nearest_labels).argmax() + + return predictions + +# Question 3: Evaluate KNN classifier +def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray, + data_test: np.ndarray, labels_test: np.ndarray, k: int) -> float: + # On commence par calculer les distances que l'on place dans notre matrice dists + dists = distance_matrix(data_test, data_train) + + # On fait ensuite les prédiction avec knn_predict + predictions = knn_predict(dists, labels_train, k) + + # Finalement on calcul notre précision en regardant les elts bien classés + correct = 0 + total = len(predictions) + for pred, true in zip(predictions, labels_test): + if pred == true: + correct += 1 + accuracy = correct / total + + return accuracy + +# Question 4: On plot nos accuracy en fonction de k +def plot_accuracy_vs_k(data_train: np.ndarray, labels_train: np.ndarray, + data_test: np.ndarray, labels_test: np.ndarray, + k_values: list) -> None: + accuracies = [] + + # On boucle sur les valeurs de k choisit et on utilise notre fonction evaluate + for k in k_values: + accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k) + accuracies.append(accuracy) + print(f"k={k}: accuracy={accuracy:.4f}") + + # On creer notre plot et on le sauvegarde + plt.figure(figsize=(10, 6)) + plt.plot(k_values, accuracies, 'bo-') + plt.xlabel('k (number of neighbors)') + plt.ylabel('Accuracy') + plt.title('KNN Classification Accuracy vs k') + plt.grid(True) + + os.makedirs('results', exist_ok=True) + + plt.savefig('results/knn.png') + plt.close() + +if __name__ == "__main__": + # On charge les données CIFAR + cifar_dir = "data/cifar-10-batches-py" + all_data, all_labels = read_cifar(cifar_dir) + + # On split en train et test + data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9) + + # On creer le plot pour étudier l impact de k sur la précision (k allant de 1 a 20) + k_values = list(range(1, 21)) + plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_values)