From 06f5ead5b4b3f8338176319e7f250852b234da48 Mon Sep 17 00:00:00 2001 From: Aya SAIDI <aya.saidi@auditeur.ec-lyon.fr> Date: Tue, 8 Nov 2022 01:22:38 +0100 Subject: [PATCH] Update knn.py --- knn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/knn.py b/knn.py index 55b1f15..b99e3c2 100644 --- a/knn.py +++ b/knn.py @@ -1,6 +1,10 @@ #Libraries import numpy as np import matplotlib.pyplot as plt +import math +import random +from read_cifar import * + #Functions def distance_matrix(Y , X): #This function takes as parameters two matrices X and Y @@ -26,7 +30,7 @@ def knn_predict(dists, labels_train, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k): #This function evaluates the knn classifier rate - labels_test__pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k) + labels_test_pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k) num_samples= data_test.shape[0] num_correct= (labels_test == labels_test_pred).sum().item() accuracy= 100 * (num_correct / num_samples) #The accuracy is the percentage of the correctly predicted classes -- GitLab