Skip to content
Snippets Groups Projects
Commit 562e7af5 authored by Duperret Loris's avatar Duperret Loris
Browse files

Delete knn.py

parent efe8297d
Branches
No related tags found
No related merge requests found
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
def distance_matrix(matrix1, matrix2):
    """Return the pairwise Euclidean (L2) distances between rows of two matrices.

    The distances are computed without an explicit loop through the
    expansion ``||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2``.

    Args:
        matrix1: 2-D array-like of shape ``(n, d)``.
        matrix2: 2-D array-like of shape ``(m, d)``.

    Returns:
        Array of shape ``(n, m)`` whose entry ``(i, j)`` is the distance
        between row ``i`` of ``matrix1`` and row ``j`` of ``matrix2``.
        Floating-point inputs keep their precision; any other dtype is
        promoted to ``float64`` before the computation.
    """
    rows1 = np.asarray(matrix1)
    rows2 = np.asarray(matrix2)
    # Integer inputs (e.g. uint8 pixel data) would silently wrap around when
    # squared, so promote anything that is not already floating point.
    if not np.issubdtype(rows1.dtype, np.floating):
        rows1 = rows1.astype(np.float64)
    if not np.issubdtype(rows2.dtype, np.floating):
        rows2 = rows2.astype(np.float64)

    # Squared norm of every row, kept as column vectors for broadcasting.
    norms1 = np.sum(rows1 ** 2, axis=1, keepdims=True)
    norms2 = np.sum(rows2 ** 2, axis=1, keepdims=True)
    # Cross term: dot product between every pair of rows.
    dot_product = np.dot(rows1, rows2.T)

    squared_dists = norms1 - 2 * dot_product + norms2.T
    # Rounding errors can leave tiny negative values for (near-)identical
    # rows; clamp them so the square root never yields NaN.
    return np.sqrt(np.maximum(squared_dists, 0))
def knn_predict(dists, labels_train, k):
    """Predict a label for each test sample by majority vote of its k nearest neighbours.

    Args:
        dists: 2-D array in which row ``i`` holds the distances from test
            sample ``i`` to every training sample.
        labels_train: array of non-negative integer labels, indexed like the
            columns of ``dists``.
        k: number of nearest neighbours taking part in the vote.

    Returns:
        1-D array with one predicted label per row of ``dists``, using the
        dtype of ``labels_train``. When several labels share the highest
        count, the smallest of them is returned.
    """
    predictions = np.zeros(dists.shape[0], dtype=labels_train.dtype)
    for row, row_dists in enumerate(dists):
        # Training indices ordered from closest to farthest, cut to k.
        neighbours = np.argsort(row_dists)[:k]
        # Count how often each label appears among the neighbours.
        votes = np.bincount(labels_train[neighbours])
        # argmax returns the first maximum, i.e. the smallest winning label.
        predictions[row] = np.argmax(votes)
    return predictions
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the accuracy of the k-nearest-neighbours classifier on a test set.

    Args:
        data_train: training samples, one per row.
        labels_train: labels of the training samples.
        data_test: test samples, one per row.
        labels_test: ground-truth labels of the test samples.
        k: number of neighbours used for the vote.

    Returns:
        The value of ``accuracy_score(labels_test, predictions)``.
    """
    # Distances from every test sample (rows) to every training sample (columns).
    test_to_train = distance_matrix(data_test, data_train)
    guesses = knn_predict(test_to_train, labels_train, k)
    return accuracy_score(labels_test, guesses)
# NOTE(review): split_factor is assigned but never read in this file;
# presumably it was meant for a train/test split performed elsewhere.
split_factor = 0.9
k_values = range(1, 21)
accuracies = []
# NOTE(review): data_train, labels_train, data_test and labels_test are not
# defined anywhere in this file; they must be provided (e.g. loaded and
# split) before this loop can run.
for k in k_values:
    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
    accuracies.append(accuracy)

# Plot the accuracy obtained for each value of k.
plt.figure(figsize=(8, 6))
plt.plot(k_values, accuracies, marker='o')
plt.title('KNN Accuracy vs. k')
plt.xlabel('k')
plt.ylabel('Accuracy')
plt.grid(True)

# savefig does not create missing directories, so make sure "results" exists
# before saving the plot as "knn.png" inside it.
os.makedirs('results', exist_ok=True)
plt.savefig('results/knn.png')

# Show the plot (optional)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment