Skip to content
Snippets Groups Projects
Commit ce26358c authored by BaptisteBrd's avatar BaptisteBrd
Browse files

knn final

parent 2418acfd
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
import read_cifar
import matplotlib.pyplot as plt
def distance_matrix(a,b): def distance_matrix(a,b):
sum_a = np.sum(a**2, axis=1, keepdims=True) sum_a = np.sum(a**2, axis=1, keepdims=True)
...@@ -27,12 +29,43 @@ def knn_predict(dists, labels_train, k): ...@@ -27,12 +29,43 @@ def knn_predict(dists, labels_train, k):
return(np.array(predicted_labels)) return(np.array(predicted_labels))
def evaluate_knn(data_train, labels_train, data_test, labels_test, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
rate = 0
dist_train_test = distance_matrix(data_train, data_test)
prediction = knn_predict(dist_train_test, labels_train, k)
for j in range(len(prediction)):
if prediction[j]==labels_test[j]:
rate +=1
rate = rate/len(prediction)
return rate
def knn_final():
range_k = range(1,20)
rates = []
data,labels = read_cifar.read_cifar("data/cifar-10-batches-py")
data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
for k in range_k :
rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
rates.append(rate_k)
plt.figure(figsize=(10, 7))
plt.xlabel('k')
plt.ylabel('Accuracy rate')
plt.plot(range_k, rates)
plt.title("Accuracy rate = f(k)")
plt.legend()
plt.grid(True)
plt.show()
if __name__ == "__main__" : if __name__ == "__main__" :
a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) knn_final()
print(distance_matrix(a1,b1)) #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
\ No newline at end of file #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
#print(distance_matrix(a1,b1))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment