Skip to content
Snippets Groups Projects
Commit 19373e89 authored by BaptisteBrd's avatar BaptisteBrd
Browse files

exchange data test and train

parent ce26358c
No related branches found
No related tags found
No related merge requests found
...@@ -2,11 +2,16 @@ import numpy as np ...@@ -2,11 +2,16 @@ import numpy as np
import read_cifar import read_cifar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def distance_matrix(a,b): def distance_matrix(A,B):
sum_a = np.sum(a**2, axis=1, keepdims=True)
sum_b = np.sum(b**2, axis=1, keepdims=True) sum_of_squares_A = np.sum(A**2, axis=1,keepdims=True)
dist = np.sqrt(-2 * a.dot(b.T) + sum_a + sum_b) sum_of_squares_B = np.sum(B**2, axis=1,keepdims=True).T
return dist dot_product = np.dot(A, B.T)
dists=np.sqrt(sum_of_squares_A+sum_of_squares_B-2*dot_product)
return dists
...@@ -30,8 +35,10 @@ def knn_predict(dists, labels_train, k): ...@@ -30,8 +35,10 @@ def knn_predict(dists, labels_train, k):
def evaluate_knn(data_train, labels_train, data_test, labels_test, k): def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
rate = 0 rate = 0
dist_train_test = distance_matrix(data_train, data_test) dist_train_test = distance_matrix(data_test, data_train)
prediction = knn_predict(dist_train_test, labels_train, k) prediction = knn_predict(dist_train_test, labels_train, k)
print(len(prediction))
print(len(labels_test))
for j in range(len(prediction)): for j in range(len(prediction)):
if prediction[j]==labels_test[j]: if prediction[j]==labels_test[j]:
rate +=1 rate +=1
...@@ -46,6 +53,7 @@ def knn_final(): ...@@ -46,6 +53,7 @@ def knn_final():
data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9) data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
for k in range_k : for k in range_k :
print(k)
rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k) rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
rates.append(rate_k) rates.append(rate_k)
...@@ -68,4 +76,3 @@ if __name__ == "__main__" : ...@@ -68,4 +76,3 @@ if __name__ == "__main__" :
#a1 = np.array([[0,0,1],[0,0,0],[1,1,2]]) #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
#b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
#print(distance_matrix(a1,b1)) #print(distance_matrix(a1,b1))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment