Skip to content
Snippets Groups Projects
Commit 6c0a1bf6 authored by Audard Lucile's avatar Audard Lucile
Browse files

Update knn.py

parent 7a96330d
No related branches found
No related tags found
No related merge requests found
......@@ -43,35 +43,25 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
return accuracy
if __name__ == "__main__":
# Extraction of the data from Cifar database
data, labels = read_cifar("./data/cifar-10-batches-py")
data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.9)
# Formatting the data into training and testing sets
data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.1)
# Data to plot
k_list = [k for k in range(1, 21)]
accuracy = [evaluate_knn(data_train, labels_train, data_test, labels_test, k) for k in range (1, 21)]
# Plot the graph
plt.close()
plt.plot(k_list, accuracy)
plt.title("Variation of k-nearest neighbors method accuracy for k from 1 to 20")
plt.xlabel("k value")
plt.ylabel("Accuracy")
plt.grid(True, which='both')
plt.savefig("results/knn.png")
# x_test = np.array([[1,2],[4,6]])
# x_labels_test = np.array([0,1])
# x_train = np.array([[2,4],[7,2],[4,6]])
# x_labels_train = np.array([0,1,1])
# dist = distance_matrix(x_test, x_train)
# accuracy = evaluate_knn(x_train, x_labels_train, x_test, x_labels_test, 1)
# print(accuracy)
plt.show()
#plt.savefig("results/knn.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment