From a34e29fbe80bcf413be18bf93ec76d8af1d4afe4 Mon Sep 17 00:00:00 2001 From: Delorme Antonin <antonin.delorme@etu.ec-lyon.fr> Date: Fri, 10 Nov 2023 19:04:12 +0000 Subject: [PATCH] Upload New File --- knn.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 knn.py diff --git a/knn.py b/knn.py new file mode 100644 index 0000000..7748c63 --- /dev/null +++ b/knn.py @@ -0,0 +1,77 @@ +import pickle +import numpy as np +import read_cifar as rd +from math import * +import matplotlib.pyplot as plt + +def distance_matrix(a,b): + at=np.transpose(a) + bt=np.transpose(b) + som_carr_a=np.sum(np.square(a), axis=1, keepdims=True) + som_carr_b=np.sum(np.square(b), axis=1, keepdims=True) + + prod=np.dot(a,bt) + + return ( np.sqrt(som_carr_a + np.transpose(som_carr_b) - 2 * prod) ) + + + + +def knn_predict(dists,labels_train,k): + predict=[] + for i in range(len(dists)): + Glob_dist=[] + Glob_min=[] + for j in range(len(dists[i])): + Glob_dist.append(dists[i][j]) + for p in range(k): + m = min(Glob_dist) + index = Glob_dist.index(m) + Glob_min.append(labels_train[index]) + del(Glob_dist[index]) + + Temp=np.bincount(Glob_min) + predict.append(list(Temp).index(max(Temp))) + + + return predict + + +def evaluate_knn(data_train,data_test,labels_train,labels_test,k,dist): + pred=knn_predict(dist, labels_train, k) + tot=0 + bon=0 + for i in range(len(labels_test)): + if labels_test[i]==pred[i]: + bon+=1 + tot+=1 + + print("Accuracy :",bon/tot) + return bon/tot + + +if __name__ == "__main__": + #a=np.random.random((2,2)) + #b=np.random.random((2,2)) + #c=np.eye(2) + + path="data\\cifar-10-batches-py" + labels,data=rd.read_cifar(path) + (data_train,labels_train,data_test,labels_test)=rd.split_dataset(labels,data,0.1) + + data_train,data_test=data_train.astype(np.float32),data_test.astype(np.float32) + labels_train,labels_test=np.array(labels_train, dtype=np.int16),np.array(labels_test, dtype=np.int16) + print("Taille train : ",len(data_train)) + print("Taille test :",len(data_test)) + dist=distance_matrix(data_test,data_train) + Res=[] + for k in range(1,21): + print(k,":") + Res.append(evaluate_knn(data_train, data_test, labels_train, labels_test, k,dist)) + + + plt.figure() + plt.title("Accurracy with different k") + plt.plot([i for i in range (1,21)],Res) + + \ No newline at end of file -- GitLab