Skip to content
Snippets Groups Projects
Commit 4093c7bd authored by Danjou Pierre's avatar Danjou Pierre
Browse files

commi

parent fb477c85
No related branches found
No related tags found
No related merge requests found
...@@ -45,6 +45,26 @@ def evaluate_knn(data_train, labels_train, data_test, labels_tests, k): ...@@ -45,6 +45,26 @@ def evaluate_knn(data_train, labels_train, data_test, labels_tests, k):
accuracy = (labels_tests == result_test).sum() / N accuracy = (labels_tests == result_test).sum() / N
return(accuracy) return(accuracy)
def bench_knn():
k_indices = [i for i in range(20) if i!=0]
accuracies = []
# Loop on the k_indices to get all the accuracies
for k in k_indices:
accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
accuracies.append(accuracy)
print(accuracy)
# Save and show the graph of accuracies
fig = plt.figure()
plt.plot(k_indices, accuracies)
plt.title("Accuracy as function of k")
plt.show()
plt.savefig(r'C:\Users\danjo\Documents\GitHub\image-classification\results')
return()
...@@ -55,14 +75,15 @@ if __name__ == "__main__": ...@@ -55,14 +75,15 @@ if __name__ == "__main__":
data, labels = read_cifar(main_path) data, labels = read_cifar(main_path)
data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9) data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)
print(labels_test)
dists = distance_matrix(data_test, data_train) dists = distance_matrix(data_test, data_train)
#print(dists)
r = knn_predict(dists, labels_train, 10) r = knn_predict(dists, labels_train, 10)
accurancy = evaluate_knn(data_train, labels_train, data_test, labels_test, 10) accurancy = evaluate_knn(data_train, labels_train, data_test, labels_test, 10)
print(r)
print(accurancy) print(accurancy)
bench_knn()
results.png

2.34 KiB

results/knn.png

25.6 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment