"""k-nearest-neighbour classification helpers and a CIFAR-10 evaluation script.

The module exposes three functions:

* ``distance_matrix`` - pairwise Euclidean distances between two sets of rows.
* ``knn_predict``     - majority vote among the k closest training samples.
* ``evaluate_knn``    - accuracy of ``knn_predict`` against known test labels.

Run as a script, it evaluates k = 1..20 on a CIFAR-10 split and plots the
resulting accuracies.
"""
import os
import pickle
from math import *

import numpy as np


def _as_float_array(values):
    """Return *values* as an ndarray, promoting non-float dtypes to float64."""
    arr = np.asarray(values)
    if not np.issubdtype(arr.dtype, np.inexact):
        # Integer inputs (e.g. uint8 pixels) would overflow when squared.
        arr = arr.astype(np.float64)
    return arr


def distance_matrix(a, b):
    """Return the Euclidean distance between every row of *a* and every row of *b*.

    Uses the expansion ``|x - y|^2 = |x|^2 + |y|^2 - 2 x.y`` so the whole
    matrix is obtained from a single matrix product.

    Args:
        a: 2-D array-like of shape (n_a, d).
        b: 2-D array-like of shape (n_b, d).

    Returns:
        ndarray of shape (n_a, n_b); entry ``[i, j]`` is the distance between
        ``a[i]`` and ``b[j]``. Floating inputs keep NumPy's usual result
        dtype; other dtypes are promoted to float64 first.
    """
    a = _as_float_array(a)
    b = _as_float_array(b)

    sq_norms_a = np.sum(np.square(a), axis=1, keepdims=True)  # (n_a, 1)
    sq_norms_b = np.sum(np.square(b), axis=1, keepdims=True)  # (n_b, 1)
    cross = np.dot(a, np.transpose(b))                        # (n_a, n_b)

    sq_dists = sq_norms_a + np.transpose(sq_norms_b) - 2 * cross
    # Rounding can leave tiny negative values where the true distance is
    # zero (or close to it); without the clamp np.sqrt would return NaN.
    np.maximum(sq_dists, 0, out=sq_dists)
    return np.sqrt(sq_dists)


def knn_predict(dists, labels_train, k):
    """Predict a label for each row of *dists* by k-nearest-neighbour voting.

    Args:
        dists: 2-D array-like; ``dists[i][j]`` is the distance between test
            sample ``i`` and training sample ``j``.
        labels_train: non-negative integer label of each training sample,
            indexed like the columns of *dists*.
        k: number of neighbours that vote; must satisfy
            ``1 <= k <= number of training samples``.

    Returns:
        A list with one predicted label (Python ``int``) per row of *dists*.
        When several labels receive the same number of votes, the smallest
        label wins.

    Raises:
        ValueError: if *k* is smaller than 1 or larger than the number of
            training samples in a row.
    """
    labels = np.asarray(labels_train)
    predictions = []
    for row in dists:
        row = np.asarray(row)
        if not 1 <= k <= row.shape[0]:
            raise ValueError(
                f"k must be between 1 and {row.shape[0]}, got {k}"
            )
        # Rank the ORIGINAL column indices by distance. A stable sort keeps
        # equal distances in index order, so ties resolve to the first
        # training sample. Indices are never shifted by removals, so they
        # always address the right entry of labels_train.
        nearest = np.argsort(row, kind="stable")[:k]
        votes = np.bincount(labels[nearest])
        # argmax returns the first maximum, i.e. the smallest winning label.
        predictions.append(int(np.argmax(votes)))
    return predictions


def evaluate_knn(data_train, data_test, labels_train, labels_test, k, dist):
    """Print and return the accuracy of k-NN predictions on the test labels.

    Args:
        data_train: unused; kept so existing callers keep working.
        data_test: unused; kept so existing callers keep working.
        labels_train: labels of the training samples (columns of *dist*).
        labels_test: true labels of the test samples (rows of *dist*).
        k: number of neighbours passed to ``knn_predict``.
        dist: precomputed test-by-train distance matrix.

    Returns:
        Fraction of entries of *labels_test* whose prediction matches.

    Raises:
        ZeroDivisionError: if *labels_test* is empty.
        IndexError: if *dist* yields fewer predictions than there are labels.
    """
    pred = knn_predict(dist, labels_train, k)
    total = len(labels_test)
    if len(pred) < total:
        raise IndexError(
            f"got {len(pred)} predictions for {total} test labels"
        )
    correct = sum(1 for truth, guess in zip(labels_test, pred) if truth == guess)

    accuracy = correct / total
    print("Accuracy :", accuracy)
    return accuracy


if __name__ == "__main__":
    # Only the script needs the dataset loader and the plotting backend;
    # importing them here keeps the helpers above importable without them.
    import matplotlib.pyplot as plt
    import read_cifar as rd

    # os.path.join instead of a hard-coded backslash so the path also
    # resolves on non-Windows systems.
    path = os.path.join("data", "cifar-10-batches-py")
    labels, data = rd.read_cifar(path)
    (data_train, labels_train, data_test, labels_test) = rd.split_dataset(
        labels, data, 0.1
    )

    data_train = data_train.astype(np.float32)
    data_test = data_test.astype(np.float32)
    labels_train = np.array(labels_train, dtype=np.int16)
    labels_test = np.array(labels_test, dtype=np.int16)
    print("Taille train : ", len(data_train))
    print("Taille test :", len(data_test))

    # The distance matrix does not depend on k, so compute it once.
    dist = distance_matrix(data_test, data_train)

    k_values = list(range(1, 21))
    Res = []
    for k in k_values:
        print(k, ":")
        Res.append(
            evaluate_knn(data_train, data_test, labels_train, labels_test, k, dist)
        )

    plt.figure()
    plt.title("Accurracy with different k")
    plt.plot(k_values, Res)
    # Without show() the figure is never displayed when run as a plain script.
    plt.show()