diff --git a/knn.py b/knn.py index fa4aba5a28931563243d4335dd93a026b007b77c..2833ea71fe0c9caaaacf055131a9d703b99fb5b7 100644 --- a/knn.py +++ b/knn.py @@ -1,4 +1,6 @@ import numpy as np +import read_cifar +import matplotlib.pyplot as plt def distance_matrix(a,b): sum_a = np.sum(a**2, axis=1, keepdims=True) @@ -27,12 +29,43 @@ def knn_predict(dists, labels_train, k): return(np.array(predicted_labels)) def evaluate_knn(data_train, labels_train, data_test, labels_test, k): - + rate = 0 + dist_train_test = distance_matrix(data_train, data_test) + prediction = knn_predict(dist_train_test, labels_train, k) + for j in range(len(prediction)): + if prediction[j]==labels_test[j]: + rate +=1 + rate = rate/len(prediction) + return rate + +def knn_final(): + range_k = range(1,20) + rates = [] + + data,labels = read_cifar.read_cifar("data/cifar-10-batches-py") + data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9) + + for k in range_k : + rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k) + rates.append(rate_k) + + plt.figure(figsize=(10, 7)) + plt.xlabel('k') + plt.ylabel('Accuracy rate') + plt.plot(range_k, rates) + plt.title("Accuracy rate = f(k)") + plt.legend() + plt.grid(True) + plt.show() + if __name__ == "__main__" : - a1 = np.array([[0,0,1],[0,0,0],[1,1,2]]) - b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) - print(distance_matrix(a1,b1)) \ No newline at end of file + + knn_final() + #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]]) + #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]]) + #print(distance_matrix(a1,b1)) +