Skip to content
Snippets Groups Projects
Commit 06f5ead5 authored by Saidi Aya's avatar Saidi Aya
Browse files

Update knn.py

parent f98665d7
No related branches found
No related tags found
No related merge requests found
#Libraries #Libraries
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import math
import random
from read_cifar import *
#Functions #Functions
def distance_matrix(Y , X): def distance_matrix(Y , X):
#This function takes as parameters two matrices X and Y #This function takes as parameters two matrices X and Y
...@@ -26,7 +30,7 @@ def knn_predict(dists, labels_train, k): ...@@ -26,7 +30,7 @@ def knn_predict(dists, labels_train, k):
def evaluate_knn(data_train, labels_train, data_test, labels_test, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
#This function evaluates the knn classifier rate #This function evaluates the knn classifier rate
labels_test__pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k) labels_test_pred=knn_predict(distance_matrix(data_train, data_test), labels_train, k)
num_samples= data_test.shape[0] num_samples= data_test.shape[0]
num_correct= (labels_test == labels_test_pred).sum().item() num_correct= (labels_test == labels_test_pred).sum().item()
accuracy= 100 * (num_correct / num_samples) #The accuracy is the percentage of the correctly predicted classes accuracy= 100 * (num_correct / num_samples) #The accuracy is the percentage of the correctly predicted classes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment