Skip to content
Snippets Groups Projects
Commit efa3fceb authored by selalimi's avatar selalimi
Browse files

Update knn file

parent 6f171b98
Branches
No related tags found
No related merge requests found
......@@ -5,16 +5,15 @@ import matplotlib.pyplot as plt
import plotly.graph_objects as go
# Commentaire global expliquant le but du code
'''Here is the code to compute the L2 Euclidean distance matrix and predict labels using k-nearest neighbors:'''
# Create distance Matrix
'''
Arguments:
-Deux matrices.
- Two matrices.
Returns:
dists : la matrice de distances euclidiennes L2.
La computation de cette fonction doit être effectuée uniquement avec des manipulations de matrices.
dists: the L2 Euclidean distance matrix.
The computation of this function should be done solely through matrix manipulations.
'''
def distance_matrix(X, Y):
XX = np.sum(X ** 2, axis=1, keepdims=True)
......@@ -26,12 +25,12 @@ def distance_matrix(X, Y):
# KNN predict
'''
Arguments:
-dists : la matrice de distances entre l'ensemble d'entraînement et l'ensemble de test.
-labels_train : les étiquettes d'entraînement.
- k : le nombre de voisins.
- dists: the distance matrix between the training set and the test set.
- labels_train: training labels.
- k: the number of neighbors.
Returns:
-Les étiquettes prédites pour les éléments de data_test.
- Predicted labels for the elements in data_test.
'''
def knn_predict(dists, labels_train, k):
n_test = dists.shape[0]
......@@ -46,15 +45,14 @@ def knn_predict(dists, labels_train, k):
'''Here is the code to evaluate k-nearest neighbors and plot the accuracy as a function of k:'''
'''
Arguments:
-data_train : les données d'entraînement.
-labels_train : les étiquettes correspondantes.
-data_test : les données de test.
-labels_test : les étiquettes correspondantes.
-k : le nombre de voisins.
- data_train: training data.
- labels_train: corresponding labels.
- data_test: test data.
- labels_test: corresponding labels.
- k: the number of neighbors.
Returns:
-La précision du modèle Knn : le taux de classification entre les valeurs prédites et les observations
réelles des données de test.
- Accuracy of the Knn model: the classification rate between predicted values and actual observations from test data.
'''
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
dists = distance_matrix(data_test, data_train)
......@@ -63,13 +61,13 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
return accuracy
# Plot Accuracy of KNN model
'''The function plots the variation of accuracy with the number of neighbors K.'''
'''
******La fonction trace la variation de la précision en fonction du nombre de voisins K****
Arguments:
-X_train : données d'entraînement
-y_train : étiquettes d'entraînement
-X_test : données de test
-y_test : étiquettes de test
- X_train: training data.
- y_train: training labels.
- X_test: test data.
- y_test: test labels.
'''
def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
neighbors = np.arange(1, max_k + 1)
......@@ -79,4 +77,3 @@ def plot_KNN(X_train, y_train, X_test, y_test, max_k=20):
plt.ylabel('Accuracy')
plt.title('Variation of Accuracy with K')
plt.savefig("Results/knn.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment