Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • main default protected
2 results

knn.py

Blame
  • Loris's avatar
    Duperret Loris authored
    7612de50
    History
    knn.py 1.27 KiB
    import numpy as np
    from sklearn.metrics import accuracy_score
    
    
    
    def distance_matrix(mat1, mat2):
        """Return the pairwise Euclidean distances between the rows of two matrices.

        The squared distance is expanded as ||a||^2 - 2 a.b + ||b||^2 so that the
        whole matrix is obtained from a single matrix product, without Python loops.

        Parameters
        ----------
        mat1 : ndarray of shape (n, d)
            First set of row vectors.
        mat2 : ndarray of shape (m, d)
            Second set of row vectors (same number of columns as ``mat1``).

        Returns
        -------
        ndarray of shape (n, m)
            ``dists[i, j]`` is the Euclidean distance between ``mat1[i]`` and
            ``mat2[j]``; every entry is finite and non-negative.
        """
        # Column vector (n, 1) and row vector (1, m) of squared norms, so the
        # sum below broadcasts to an (n, m) matrix.
        norms1 = np.sum(mat1 ** 2, axis=1, keepdims=True)
        norms2 = np.sum(mat2 ** 2, axis=1, keepdims=True)

        squared = norms1 - 2 * np.dot(mat1, mat2.T) + norms2.T

        # Cancellation in the expansion can leave tiny negative values
        # (e.g. -1e-12) for identical or nearly identical rows; clamp them,
        # otherwise np.sqrt would return NaN for those entries.
        squared = np.maximum(squared, 0)

        return np.sqrt(squared)
    
    def knn_predict(dists, labels_train, k):
        """Predict a label for each test sample by majority vote of its k nearest neighbours.

        Parameters
        ----------
        dists : array-like of shape (n_test, n_train)
            ``dists[i, j]`` is the distance between test sample ``i`` and
            training sample ``j`` (for example the output of ``distance_matrix``).
        labels_train : array-like of shape (n_train,)
            Label of each training sample. Any sortable label type works
            (non-negative or negative integers, strings, ...).
        k : int
            Number of nearest neighbours that vote. If ``k`` exceeds
            ``n_train``, every training sample votes.

        Returns
        -------
        ndarray of shape (n_test,)
            Predicted labels, with the dtype of ``labels_train``. A tied vote is
            resolved in favour of the smallest label.

        Raises
        ------
        ValueError
            If ``k`` is smaller than 1 (there would be no vote to take).
        """
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}")

        dists = np.asarray(dists)
        labels_train = np.asarray(labels_train)

        num_test_samples = dists.shape[0]
        pred_labels = np.zeros(num_test_samples, dtype=labels_train.dtype)

        for i in range(num_test_samples):
            # Indices of the k training samples closest to test sample i.
            k_nearest_indices = np.argsort(dists[i])[:k]
            k_nearest_labels = labels_train[k_nearest_indices]

            # np.unique returns the distinct labels in sorted order, so argmax
            # (first maximum) breaks ties towards the smallest label -- the same
            # rule as np.bincount + argmax, but without requiring labels to be
            # non-negative integers.
            values, counts = np.unique(k_nearest_labels, return_counts=True)
            pred_labels[i] = values[np.argmax(counts)]

        return pred_labels
    
    
    
    def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
        """Return the accuracy of a k-nearest-neighbours classifier on a test set.

        Each test sample is classified by majority vote among its ``k`` nearest
        training samples (distances from ``distance_matrix``, votes from
        ``knn_predict``). The predictions are then scored against ``labels_test``
        with ``sklearn.metrics.accuracy_score``.

        Returns
        -------
        float
            Fraction of test samples whose predicted label equals the true label.
        """
        # Rows index test samples and columns index training samples, which is
        # the layout knn_predict expects.
        test_to_train = distance_matrix(data_test, data_train)
        predictions = knn_predict(test_to_train, labels_train, k)
        return accuracy_score(labels_test, predictions)