Skip to content
Snippets Groups Projects
Commit 19373e89 authored by BaptisteBrd's avatar BaptisteBrd
Browse files

exchange data test and train

parent ce26358c
Branches
No related tags found
No related merge requests found
......@@ -2,11 +2,16 @@ import numpy as np
import read_cifar
import matplotlib.pyplot as plt
def distance_matrix(a,b):
sum_a = np.sum(a**2, axis=1, keepdims=True)
sum_b = np.sum(b**2, axis=1, keepdims=True)
dist = np.sqrt(-2 * a.dot(b.T) + sum_a + sum_b)
return dist
def distance_matrix(A,B):
sum_of_squares_A = np.sum(A**2, axis=1,keepdims=True)
sum_of_squares_B = np.sum(B**2, axis=1,keepdims=True).T
dot_product = np.dot(A, B.T)
dists=np.sqrt(sum_of_squares_A+sum_of_squares_B-2*dot_product)
return dists
......@@ -30,8 +35,10 @@ def knn_predict(dists, labels_train, k):
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
rate = 0
dist_train_test = distance_matrix(data_train, data_test)
dist_train_test = distance_matrix(data_test, data_train)
prediction = knn_predict(dist_train_test, labels_train, k)
print(len(prediction))
print(len(labels_test))
for j in range(len(prediction)):
if prediction[j]==labels_test[j]:
rate +=1
......@@ -46,6 +53,7 @@ def knn_final():
data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
for k in range_k :
print(k)
rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
rates.append(rate_k)
......@@ -68,4 +76,3 @@ if __name__ == "__main__" :
#a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
#b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
#print(distance_matrix(a1,b1))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment