Skip to content
Snippets Groups Projects
Commit ce26358c authored by BaptisteBrd's avatar BaptisteBrd
Browse files

knn final

parent 2418acfd
Branches
No related tags found
No related merge requests found
import numpy as np
import read_cifar
import matplotlib.pyplot as plt
def distance_matrix(a,b):
sum_a = np.sum(a**2, axis=1, keepdims=True)
......@@ -27,12 +29,43 @@ def knn_predict(dists, labels_train, k):
return(np.array(predicted_labels))
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
rate = 0
dist_train_test = distance_matrix(data_train, data_test)
prediction = knn_predict(dist_train_test, labels_train, k)
for j in range(len(prediction)):
if prediction[j]==labels_test[j]:
rate +=1
rate = rate/len(prediction)
return rate
def knn_final():
range_k = range(1,20)
rates = []
data,labels = read_cifar.read_cifar("data/cifar-10-batches-py")
data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
for k in range_k :
rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
rates.append(rate_k)
plt.figure(figsize=(10, 7))
plt.xlabel('k')
plt.ylabel('Accuracy rate')
plt.plot(range_k, rates)
plt.title("Accuracy rate = f(k)")
plt.legend()
plt.grid(True)
plt.show()
if __name__ == "__main__" :
a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
print(distance_matrix(a1,b1))
\ No newline at end of file
knn_final()
#a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
#b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
#print(distance_matrix(a1,b1))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment