Select Git revision
knn.py 2.50 KiB
import numpy as np
import matplotlib.pyplot as plt
import os
def distance_matrix(A, B):
    """Return the pairwise Euclidean distances between the rows of A and B.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so the whole
    matrix is obtained from a single matrix product instead of a Python loop.

    Args:
        A: array-like of shape (n_a, d).
        B: array-like of shape (n_b, d).

    Returns:
        Array of shape (n_a, n_b) whose entry [i, j] is the Euclidean
        distance between A[i] and B[j].
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A_2 = np.sum(np.square(A), axis=1)[:, None]  # column vector, (n_a, 1)
    B_2 = np.sum(np.square(B), axis=1)[None, :]  # row vector, (1, n_b)
    squared = A_2 + B_2 - 2 * np.dot(A, B.T)
    # Floating-point cancellation can make the expansion slightly negative
    # (e.g. for identical or nearly identical rows); clamp so sqrt never
    # produces NaN.
    squared = np.maximum(squared, 0)
    return np.sqrt(squared)
def knn_predict(dists, labels_train, k):
    """Predict one label per test example by majority vote of its k nearest neighbours.

    Args:
        dists: array of shape (num_test, num_train); dists[i, j] is the
            distance between test example i and training example j.
        labels_train: num_train labels aligned with the columns of ``dists``.
            Any dtype works (ints, floats, strings).
        k: number of neighbours that vote; must be >= 1. If k exceeds
            num_train, every training example votes.

    Returns:
        Array of num_test predicted labels with the same dtype as
        ``labels_train``. Ties in the vote go to the smallest label, because
        ``np.unique`` returns labels sorted and ``np.argmax`` picks the first
        maximum.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    dists = np.asarray(dists)
    labels_train = np.asarray(labels_train)

    # Per row, column indices ordered from closest to farthest training example;
    # keep the first k.
    nearest_neighbor_indices = np.argsort(dists, axis=1)[:, :k]
    k_nearest_labels = labels_train[nearest_neighbor_indices]

    # Preserve the label dtype so float/string labels are not truncated or rejected.
    predicted_labels = np.empty(dists.shape[0], dtype=labels_train.dtype)
    for i, row_labels in enumerate(k_nearest_labels):
        unique_labels, counts = np.unique(row_labels, return_counts=True)
        predicted_labels[i] = unique_labels[np.argmax(counts)]
    return predicted_labels
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the classification accuracy of k-NN on the given test set.

    Args:
        data_train: training examples, shape (num_train, d).
        labels_train: labels of the training examples.
        data_test: test examples, shape (num_test, d).
        labels_test: true labels of the test examples.
        k: number of neighbours used by ``knn_predict``.

    Returns:
        Fraction of test examples whose predicted label equals the true label.
    """
    # knn_predict votes over each ROW of the distance matrix, so rows must
    # correspond to test examples and columns to training examples.
    dists = distance_matrix(data_test, data_train)
    predicted_labels = knn_predict(dists, labels_train, k)
    return np.mean(predicted_labels == np.asarray(labels_test))
def plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test, split_factor=0.9, k_values=None):
    """Plot k-NN test accuracy for several values of k and save it to results/knn.png.

    Args:
        data_train: training examples.
        labels_train: labels of the training examples.
        data_test: test examples.
        labels_test: true labels of the test examples.
        split_factor: currently unused; kept so existing callers that pass it
            keep working.
        k_values: iterable of k values to evaluate; defaults to 1..20.
    """
    if k_values is None:
        k_values = range(1, 21)
    k_values = list(k_values)
    accuracies = [
        evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        for k in k_values
    ]

    # Create the "results" directory if it doesn't exist
    os.makedirs("results", exist_ok=True)

    # Use a dedicated figure rather than pyplot's implicit current figure, so
    # repeated calls don't draw on top of each other, and close it afterwards
    # to release its memory.
    fig, ax = plt.subplots()
    ax.plot(k_values, accuracies)
    ax.set_xlabel('k')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy vs. k for KNN')
    ax.grid(True)
    fig.savefig('results/knn.png')
    plt.show()
    plt.close(fig)
if __name__ == "__main__":
    # Placeholder data: replace with your actual train/test split.
    # A seeded generator makes runs reproducible. Labels depend on the features
    # (whether a point lies above the line x + y = 1), so k-NN has something
    # learnable and the accuracy curve is meaningful rather than ~50% noise.
    rng = np.random.default_rng(0)
    data_train = rng.random((100, 2))  # Replace with your actual data
    labels_train = (data_train.sum(axis=1) > 1).astype(int)  # Replace with your actual labels
    # Placeholder test set drawn from the same distribution as the training set.
    data_test = rng.random((100, 2))
    labels_test = (data_test.sum(axis=1) > 1).astype(int)
    plot_accuracy_vs_k(data_train, labels_train, data_test, labels_test)