diff --git a/knn.py b/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a3ac343743e0ed17efc26e047ab43219262ef78
--- /dev/null
+++ b/knn.py
@@ -0,0 +1,106 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from read_cifar import read_cifar, split_dataset
+import os
+
+# Question 1: Compute distance matrix
+def distance_matrix(X1: np.ndarray, X2: np.ndarray) -> np.ndarray:
+    """
+    Compute L2 Euclidean distance matrix between two matrices.
+    Using the formula: (a-b)^2 = a^2 + b^2 - 2ab
+    
+    Args:
+        X1: First matrix of shape (n_samples_1, n_features)
+        X2: Second matrix of shape (n_samples_2, n_features)
+        
+    Returns:
+        distances: Matrix of shape (n_samples_1, n_samples_2) containing
+                  pairwise L2 distances
+    """
+    # On calcul les norme carrée de nos vecteurs directement avec numpy
+    X1_norm = np.sum(X1**2, axis=1)
+    X2_norm = np.sum(X2**2, axis=1)
+    
+    # On reshape pour pouvoir effectuer nos calculs matriciel directement
+    X1_norm = X1_norm.reshape(-1, 1) # Vecteur colonne
+    X2_norm = X2_norm.reshape(1, -1) # Vecteur ligne
+    
+    # On calcul la disctance en utilisant direct la formule : (a-b)^2 = a^2 + b^2 - 2ab
+    distances = X1_norm + X2_norm - 2 * np.dot(X1, X2.T)
+    
+    # On obtenait parfois des valeurs négative (surement a cause d'erreur numérique du calcul python)
+    distances = np.maximum(distances, 0)
+    
+    return np.sqrt(distances)
+
+# Question 2: KNN prediction
+def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int) -> np.ndarray:
+    num_test = dists.shape[0] # donne le nombre d'echantillons de test
+    predictions = np.zeros(num_test, dtype=np.int64) # on sotcke ici les predictions qu'on va faire
+    
+    # On boucle sur les echantillons de test
+    for i in range(num_test):
+        # On récupere les k plus proches voisins direct avec argsort qui permet de chopper les indices qui permetterais un orgre croissant
+        k_nearest_indices = np.argsort(dists[i])[:k]
+        k_nearest_labels = labels_train[k_nearest_indices]
+        
+        # grace a bincount on compte le nombre d'element de chaque classe, et on récupere avec argmax le majoritaire 
+        predictions[i] = np.bincount(k_nearest_labels).argmax()
+    
+    return predictions
+
+# Question 3: Evaluate KNN classifier
+def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray, 
+                data_test: np.ndarray, labels_test: np.ndarray, k: int) -> float:
+    # On commence par calculer les distances que l'on place dans notre matrice dists
+    dists = distance_matrix(data_test, data_train)
+    
+    # On fait ensuite les prédiction avec knn_predict
+    predictions = knn_predict(dists, labels_train, k)
+    
+    # Finalement on calcul notre précision en regardant les elts bien classés
+    correct = 0
+    total = len(predictions)
+    for pred, true in zip(predictions, labels_test):
+        if pred == true:
+            correct += 1
+    accuracy = correct / total
+
+    return accuracy
+
+# Question 4: On plot nos accuracy en fonction de k
+def plot_accuracy_vs_k(data_train: np.ndarray, labels_train: np.ndarray,
+                      data_test: np.ndarray, labels_test: np.ndarray,
+                      k_values: list) -> None:
+    accuracies = []
+    
+    # On boucle sur les valeurs de k choisit et on utilise notre fonction evaluate
+    for k in k_values:
+        accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
+        accuracies.append(accuracy)
+        print(f"k={k}: accuracy={accuracy:.4f}")
+    
+    # On creer notre plot et on le sauvegarde
+    plt.figure(figsize=(10, 6))
+    plt.plot(k_values, accuracies, 'bo-')
+    plt.xlabel('k (number of neighbors)')
+    plt.ylabel('Accuracy')
+    plt.title('KNN Classification Accuracy vs k')
+    plt.grid(True)
+    
+    os.makedirs('results', exist_ok=True)
+    
+    plt.savefig('results/knn.png')
+    plt.close()
+
+if __name__ == "__main__":
+    # On charge les données CIFAR
+    cifar_dir = "data/cifar-10-batches-py"
+    all_data, all_labels = read_cifar(cifar_dir)
+    
+    # On split en train et test 
+    data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9)
+    
+    # On creer le plot pour étudier l impact de k sur la précision (k allant de 1 a 20)
+    k_values = list(range(1, 21))
+    plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_values)