From 19373e897dcbf3c71b878864a3d1c40c8dff2a7a Mon Sep 17 00:00:00 2001 From: BaptisteBrd <75663738+BaptisteBrd@users.noreply.github.com> Date: Fri, 10 Nov 2023 17:33:00 +0100 Subject: [PATCH] exchange data test and train --- knn.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/knn.py b/knn.py index 2833ea7..9e44028 100644 --- a/knn.py +++ b/knn.py @@ -2,11 +2,16 @@ import numpy as np import read_cifar import matplotlib.pyplot as plt -def distance_matrix(a,b): - sum_a = np.sum(a**2, axis=1, keepdims=True) - sum_b = np.sum(b**2, axis=1, keepdims=True) - dist = np.sqrt(-2 * a.dot(b.T) + sum_a + sum_b) - return dist +def distance_matrix(A,B): + + sum_of_squares_A = np.sum(A**2, axis=1,keepdims=True) + sum_of_squares_B = np.sum(B**2, axis=1,keepdims=True).T + dot_product = np.dot(A, B.T) + + + dists=np.sqrt(sum_of_squares_A+sum_of_squares_B-2*dot_product) + + return dists @@ -30,8 +35,10 @@ def knn_predict(dists, labels_train, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k): rate = 0 - dist_train_test = distance_matrix(data_train, data_test) + dist_train_test = distance_matrix(data_test, data_train) prediction = knn_predict(dist_train_test, labels_train, k) + print(len(prediction)) + print(len(labels_test)) for j in range(len(prediction)): if prediction[j]==labels_test[j]: rate +=1 @@ -46,6 +53,7 @@ def knn_final(): data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9) for k in range_k : + print(k) rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k) rates.append(rate_k) @@ -67,5 +75,4 @@ if __name__ == "__main__" : knn_final() #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]]) #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) - #print(distance_matrix(a1,b1)) - + #print(distance_matrix(a1,b1)) \ No newline at end of file -- GitLab