Skip to content
Snippets Groups Projects
Commit 5d1fa2d1 authored by Audard Lucile's avatar Audard Lucile
Browse files

Update knn.py

parent 69c44ac4
No related branches found
No related tags found
No related merge requests found
......@@ -2,12 +2,36 @@ import numpy as np
def distance_matrix(mat1, mat2):
dists = np.sqrt(np.matmul(mat1, mat1)+ np.matmul(mat2, mat2) - 2 * np.matmul(mat1, mat2))
square1 = np.sum(np.square(mat1), axis = 1)
square2 = np.sum(np.square(mat2), axis = 1)
prod = np.dot(mat1, mat2.T)
dists = np.sqrt(square1 + square2 - 2 * prod)
return dists
def knn_predict(dists, labels_train, k):
return predicted_labels
# results matrix initialisation
predicted_labels = np.zeros(len(dists))
# loop on all the test images
for i in range(len(dists)):
# sort and keep the k shortest dists for test image i
sorted_dists = np.argsort(dists[i])
k_sorted_dists = sorted_dists[:k]
# get the matching labels_train
closest_labels = labels_train[k_sorted_dists]
# get the most common labels_train
predicted_labels[i] = np.argmax(closest_labels)
return np.array(predicted_labels)
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
dists = distance_matrix(data_test, data_train)
tot = len(data_test)
accurate = 0
predicted_labels = knn_predict(dists, labels_train, k)
for i in range(tot):
if predicted_labels[i] == labels_test[i]:
accurate += 1
accuracy = accurate/tot
return accuracy
......@@ -15,22 +39,17 @@ def knn_predict(dists, labels_train, k):
mat1 = np.array([[1, 2],
[3, 4]])
mat2 = np.array([[5, 6],
[7, 8]])
if __name__ == "__main__":
A = np.matmul(mat1, mat1)
print(A)
bench_knn()
# data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
# X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
# print(evaluate_knn(X_train, y_train, X_test, y_test, 5))
# print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
B = np.matmul(mat2, mat2)
print(B)
C = 2 * np.matmul(mat1, mat2)
print(C)
print(A + B - C)
mat = distance_matrix(mat1, mat2)
print(mat)
\ No newline at end of file
# y_test = []
# x_test = np.array([[1,2],[4,6]])
# x_train = np.array([[2,4],[7,2],[4,6]])
# y_train = [1,2,1]
# dist = distance_matrix(x_test,x_train)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment