Skip to content
Snippets Groups Projects
Commit efa3fceb authored by selalimi's avatar selalimi
Browse files

Update knn file

parent 6f171b98
No related merge requests found
......@@ -5,16 +5,15 @@ import matplotlib.pyplot as plt
import plotly.graph_objects as go
# Commentaire global expliquant le but du code
'''Here is the code to compute the L2 Euclidean distance matrix and predict labels using k-nearest neighbors:'''
# Create distance Matrix
'''
Arguments:
-Deux matrices.
- Two matrices.
Returns:
dists : la matrice de distances euclidiennes L2.
La computation de cette fonction doit être effectuée uniquement avec des manipulations de matrices.
dists: the L2 Euclidean distance matrix.
The computation of this function should be done solely through matrix manipulations.
'''
def distance_matrix(X, Y):
XX = np.sum(X ** 2, axis=1, keepdims=True)
......@@ -26,12 +25,12 @@ def distance_matrix(X, Y):
# KNN predict
'''
Arguments:
-dists : la matrice de distances entre l'ensemble d'entraînement et l'ensemble de test.
-labels_train : les étiquettes d'entraînement.
- k : le nombre de voisins.
- dists: the distance matrix between the training set and the test set.
- labels_train: training labels.
- k: the number of neighbors.
Returns:
-Les étiquettes prédites pour les éléments de data_test.
- Predicted labels for the elements in data_test.
'''
def knn_predict(dists, labels_train, k):
n_test = dists.shape[0]
......@@ -46,15 +45,14 @@ def knn_predict(dists, labels_train, k):
'''Here is the code to evaluate k-nearest neighbors and plot the accuracy as a function of k:'''
'''
Arguments:
-data_train : les données d'entraînement.
-labels_train : les étiquettes correspondantes.
-data_test : les données de test.
-labels_test : les étiquettes correspondantes.
-k : le nombre de voisins.
- data_train: training data.
- labels_train: corresponding labels.
- data_test: test data.
- labels_test: corresponding labels.
- k: the number of neighbors.
Returns:
-La précision du modèle Knn : le taux de classification entre les valeurs prédites et les observations
réelles des données de test.
- Accuracy of the Knn model: the classification rate between predicted values and actual observations from test data.
'''
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
dists = distance_matrix(data_test, data_train)
......@@ -63,13 +61,13 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
return accuracy
# Plot Accuracy of KNN model
'''The function plots the variation of accuracy with the number of neighbors K.'''
'''
******La fonction trace la variation de la précision en fonction du nombre de voisins K****
Arguments:
-X_train : données d'entraînement
-y_train : étiquettes d'entraînement
-X_test : données de test
-y_test : étiquettes de test
- X_train: training data.
- y_train: training labels.
- X_test: test data.
- y_test: test labels.
'''
def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
neighbors = np.arange(1, max_k + 1)
......@@ -79,4 +77,3 @@ def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
plt.ylabel('Accuracy')
plt.title('Variation of Accuracy with K')
plt.savefig("Results/knn.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment