Skip to content
Snippets Groups Projects
Commit fb477c85 authored by Danjou Pierre's avatar Danjou Pierre
Browse files

Update knn.py

parent 76a15be7
Branches
No related tags found
No related merge requests found
...@@ -6,17 +6,7 @@ import matplotlib.pyplot as plt ...@@ -6,17 +6,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
def distance_matrix(A, B): def distance_matrix(A, B):
"""
Compute the L2 Euclidean distance matrix between two matrices A and B.
Parameters:
A (numpy.ndarray): Matrix of shape (m, n)
B (numpy.ndarray): Matrix of shape (p, n)
Returns:
numpy.ndarray: Distance matrix of shape (m, p) where the element (i, j) is the
Euclidean distance between A[i] and B[j].
"""
# Squared norms of each row in A and B # Squared norms of each row in A and B
A_squared = np.sum(A**2, axis=1).reshape(-1, 1) # Shape (m, 1) A_squared = np.sum(A**2, axis=1).reshape(-1, 1) # Shape (m, 1)
B_squared = np.sum(B**2, axis=1).reshape(1, -1) # Shape (1, p) B_squared = np.sum(B**2, axis=1).reshape(1, -1) # Shape (1, p)
...@@ -55,19 +45,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_tests, k): ...@@ -55,19 +45,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_tests, k):
accuracy = (labels_tests == result_test).sum() / N accuracy = (labels_tests == result_test).sum() / N
return(accuracy) return(accuracy)
# def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
# dists = distance_matrix(data_test, data_train)
# # Determine the number of images in data_test
# tot = len(data_test)
# accurate = 0
# predicted_labels = knn_predict(dists, labels_train, k)
# # Count the number of images in data_test whose label has been estimated correctly
# for i in range(tot):
# if predicted_labels[i] == labels_test[i]:
# accurate += 1
# # Calculate the classification rate
# accuracy = accurate/tot
# return accuracy
if __name__ == "__main__": if __name__ == "__main__":
...@@ -88,15 +66,3 @@ if __name__ == "__main__": ...@@ -88,15 +66,3 @@ if __name__ == "__main__":
print(accurancy) print(accurancy)
# data, labels = read_cifar('data\cifar-10-batches-py')
# data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)
# k=3
# accurancies = []
# accurancy = evaluate_knn(data_train, data_test, labels_train, labels_test, k)
# accurancies.append(accurancy)
# print(accurancies)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment