Skip to content
Snippets Groups Projects
Commit 103122de authored by Denis Thomas's avatar Denis Thomas
Browse files

Upload New File

parent af94b139
No related branches found
No related tags found
No related merge requests found
knn.py 0 → 100644
import numpy as np
import matplotlib.pyplot as plt
from read_cifar import read_cifar, split_dataset
import os
# Question 1: Compute distance matrix
def distance_matrix(X1: np.ndarray, X2: np.ndarray) -> np.ndarray:
"""
Compute L2 Euclidean distance matrix between two matrices.
Using the formula: (a-b)^2 = a^2 + b^2 - 2ab
Args:
X1: First matrix of shape (n_samples_1, n_features)
X2: Second matrix of shape (n_samples_2, n_features)
Returns:
distances: Matrix of shape (n_samples_1, n_samples_2) containing
pairwise L2 distances
"""
# On calcul les norme carrée de nos vecteurs directement avec numpy
X1_norm = np.sum(X1**2, axis=1)
X2_norm = np.sum(X2**2, axis=1)
# On reshape pour pouvoir effectuer nos calculs matriciel directement
X1_norm = X1_norm.reshape(-1, 1) # Vecteur colonne
X2_norm = X2_norm.reshape(1, -1) # Vecteur ligne
# On calcul la disctance en utilisant direct la formule : (a-b)^2 = a^2 + b^2 - 2ab
distances = X1_norm + X2_norm - 2 * np.dot(X1, X2.T)
# On obtenait parfois des valeurs négative (surement a cause d'erreur numérique du calcul python)
distances = np.maximum(distances, 0)
return np.sqrt(distances)
# Question 2: KNN prediction
def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int) -> np.ndarray:
num_test = dists.shape[0] # donne le nombre d'echantillons de test
predictions = np.zeros(num_test, dtype=np.int64) # on sotcke ici les predictions qu'on va faire
# On boucle sur les echantillons de test
for i in range(num_test):
# On récupere les k plus proches voisins direct avec argsort qui permet de chopper les indices qui permetterais un orgre croissant
k_nearest_indices = np.argsort(dists[i])[:k]
k_nearest_labels = labels_train[k_nearest_indices]
# grace a bincount on compte le nombre d'element de chaque classe, et on récupere avec argmax le majoritaire
predictions[i] = np.bincount(k_nearest_labels).argmax()
return predictions
# Question 3: Evaluate KNN classifier
def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray,
data_test: np.ndarray, labels_test: np.ndarray, k: int) -> float:
# On commence par calculer les distances que l'on place dans notre matrice dists
dists = distance_matrix(data_test, data_train)
# On fait ensuite les prédiction avec knn_predict
predictions = knn_predict(dists, labels_train, k)
# Finalement on calcul notre précision en regardant les elts bien classés
correct = 0
total = len(predictions)
for pred, true in zip(predictions, labels_test):
if pred == true:
correct += 1
accuracy = correct / total
return accuracy
# Question 4: On plot nos accuracy en fonction de k
def plot_accuracy_vs_k(data_train: np.ndarray, labels_train: np.ndarray,
data_test: np.ndarray, labels_test: np.ndarray,
k_values: list) -> None:
accuracies = []
# On boucle sur les valeurs de k choisit et on utilise notre fonction evaluate
for k in k_values:
accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
accuracies.append(accuracy)
print(f"k={k}: accuracy={accuracy:.4f}")
# On creer notre plot et on le sauvegarde
plt.figure(figsize=(10, 6))
plt.plot(k_values, accuracies, 'bo-')
plt.xlabel('k (number of neighbors)')
plt.ylabel('Accuracy')
plt.title('KNN Classification Accuracy vs k')
plt.grid(True)
os.makedirs('results', exist_ok=True)
plt.savefig('results/knn.png')
plt.close()
if __name__ == "__main__":
# On charge les données CIFAR
cifar_dir = "data/cifar-10-batches-py"
all_data, all_labels = read_cifar(cifar_dir)
# On split en train et test
data_train, labels_train, data_test, labels_test = split_dataset(all_data, all_labels, split=0.9)
# On creer le plot pour étudier l impact de k sur la précision (k allant de 1 a 20)
k_values = list(range(1, 21))
plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, k_values)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment