Skip to content
Snippets Groups Projects
Commit efa3fceb authored by selalimi's avatar selalimi
Browse files

Update knn file

parent 6f171b98
No related merge requests found
...@@ -5,16 +5,15 @@ import matplotlib.pyplot as plt ...@@ -5,16 +5,15 @@ import matplotlib.pyplot as plt
import plotly.graph_objects as go import plotly.graph_objects as go
# Commentaire global expliquant le but du code
'''Here is the code to compute the L2 Euclidean distance matrix and predict labels using k-nearest neighbors:'''
# Create distance Matrix # Create distance Matrix
''' '''
Arguments: Arguments:
-Deux matrices. - Two matrices.
Returns: Returns:
dists : la matrice de distances euclidiennes L2. dists: the L2 Euclidean distance matrix.
La computation de cette fonction doit être effectuée uniquement avec des manipulations de matrices. The computation of this function should be done solely through matrix manipulations.
''' '''
def distance_matrix(X, Y): def distance_matrix(X, Y):
XX = np.sum(X ** 2, axis=1, keepdims=True) XX = np.sum(X ** 2, axis=1, keepdims=True)
...@@ -26,12 +25,12 @@ def distance_matrix(X, Y): ...@@ -26,12 +25,12 @@ def distance_matrix(X, Y):
# KNN predict # KNN predict
''' '''
Arguments: Arguments:
-dists : la matrice de distances entre l'ensemble d'entraînement et l'ensemble de test. - dists: the distance matrix between the training set and the test set.
-labels_train : les étiquettes d'entraînement. - labels_train: training labels.
- k : le nombre de voisins. - k: the number of neighbors.
Returns: Returns:
-Les étiquettes prédites pour les éléments de data_test. - Predicted labels for the elements in data_test.
''' '''
def knn_predict(dists, labels_train, k): def knn_predict(dists, labels_train, k):
n_test = dists.shape[0] n_test = dists.shape[0]
...@@ -46,15 +45,14 @@ def knn_predict(dists, labels_train, k): ...@@ -46,15 +45,14 @@ def knn_predict(dists, labels_train, k):
'''Here is the code to evaluate k-nearest neighbors and plot the accuracy as a function of k:''' '''Here is the code to evaluate k-nearest neighbors and plot the accuracy as a function of k:'''
''' '''
Arguments: Arguments:
-data_train : les données d'entraînement. - data_train: training data.
-labels_train : les étiquettes correspondantes. - labels_train: corresponding labels.
-data_test : les données de test. - data_test: test data.
-labels_test : les étiquettes correspondantes. - labels_test: corresponding labels.
-k : le nombre de voisins. - k: the number of neighbors.
Returns: Returns:
-La précision du modèle Knn : le taux de classification entre les valeurs prédites et les observations - Accuracy of the Knn model: the classification rate between predicted values and actual observations from test data.
réelles des données de test.
''' '''
def evaluate_knn(data_train, labels_train, data_test, labels_test, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
dists = distance_matrix(data_test, data_train) dists = distance_matrix(data_test, data_train)
...@@ -63,13 +61,13 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k): ...@@ -63,13 +61,13 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
return accuracy return accuracy
# Plot Accuracy of KNN model # Plot Accuracy of KNN model
'''The function plots the variation of accuracy with the number of neighbors K.'''
''' '''
******La fonction trace la variation de la précision en fonction du nombre de voisins K****
Arguments: Arguments:
-X_train : données d'entraînement - X_train: training data.
-y_train : étiquettes d'entraînement - y_train: training labels.
-X_test : données de test - X_test: test data.
-y_test : étiquettes de test - y_test: test labels.
''' '''
def plot_KNN(X_train, y_train, X_test, y_test, max_k=20): def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
neighbors = np.arange(1, max_k + 1) neighbors = np.arange(1, max_k + 1)
...@@ -79,4 +77,3 @@ def plot_KNN(X_train, y_train, X_test, y_test, max_k=20): ...@@ -79,4 +77,3 @@ def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
plt.ylabel('Accuracy') plt.ylabel('Accuracy')
plt.title('Variation of Accuracy with K') plt.title('Variation of Accuracy with K')
plt.savefig("Results/knn.png") plt.savefig("Results/knn.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment