"""k-nearest-neighbours classification implemented with NumPy.

The module provides a pairwise Euclidean distance helper, a majority-vote
k-NN predictor and an accuracy evaluator.  Run as a script, it reads
``data_batch_1``, splits it into train/test subsets and plots the test
accuracy for k = 1..20.
"""
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor


def _as_float_array(data):
    """Return *data* as an ndarray, promoting non-float dtypes to float64."""
    arr = np.asarray(data)
    if np.issubdtype(arr.dtype, np.floating):
        return arr
    return arr.astype(np.float64)


def distance_matrix(data_test, data_train):
    """Return the Euclidean distance between every pair of rows.

    Args:
        data_test: 2-D array of shape (N, d).
        data_train: 2-D array of shape (M, d).

    Returns:
        Array of shape (N, M) where ``dists[i, j]`` is the distance between
        the i-th row of ``data_test`` and the j-th row of ``data_train``.
        In summary, each row holds the distances from one ``data_test``
        point to every ``data_train`` point.  An empty ``data_test`` yields
        an array of shape (0, M).
    """
    # Integer inputs (e.g. uint8 pixel values) must be promoted before the
    # subtraction and squaring, otherwise the arithmetic wraps around.
    data_test = _as_float_array(data_test)
    data_train = _as_float_array(data_train)

    dists = np.empty(
        (data_test.shape[0], data_train.shape[0]),
        dtype=np.result_type(data_test, data_train),
    )
    # One row at a time: a fully broadcast difference would need an
    # (N, M, d) temporary array.
    for i, row in enumerate(data_test):
        diff = data_train - row
        dists[i] = np.sqrt(np.sum(diff * diff, axis=1))
    return dists


def knn_predict(dists, labels_train, k, num_classes=10):
    """Predict a label for each test point by majority vote of k neighbours.

    Args:
        dists: 2-D array of shape (n_train, n_test); ``dists[i, j]`` is the
            distance between training point i and test point j.  This is
            what ``distance_matrix(data_train, data_test)`` returns.
        labels_train: 1-D array of n_train labels in ``0..num_classes - 1``.
            Labels outside that range never receive votes.
        k: Number of nearest neighbours that vote (at least 1).  Values
            larger than n_train use every training point.
        num_classes: Number of candidate labels; defaults to 10.

    Returns:
        1-D integer array of n_test predicted labels.  Ties are resolved in
        favour of the smallest label.

    Raises:
        ValueError: If ``k`` is smaller than 1 or the shapes of ``dists``
            and ``labels_train`` are inconsistent.
    """
    dists = np.asarray(dists)
    labels_train = np.asarray(labels_train)

    if k < 1:
        raise ValueError("k must be at least 1, got {}".format(k))
    if (
        dists.ndim != 2
        or labels_train.ndim != 1
        or dists.shape[0] != labels_train.shape[0]
    ):
        raise ValueError(
            "dists must have shape (n_train, n_test) with one row per entry "
            "of labels_train; got dists {} and labels_train {}".format(
                dists.shape, labels_train.shape
            )
        )

    # Row indices of the k closest training points for every test column.
    nearest = np.argsort(dists, axis=0)[:k]
    nearest_labels = labels_train[nearest]

    # votes[c, j] = how many of test point j's neighbours carry label c.
    votes = np.stack(
        [np.count_nonzero(nearest_labels == c, axis=0) for c in range(num_classes)]
    )
    return votes.argmax(axis=0).astype(int)


def evaluate_knn(data_train, labels_train, data_test, labels_test, k, dists=None):
    """Return the k-NN classification accuracy on the test set, in percent.

    Args:
        data_train: 2-D array of training samples, one per row.
        labels_train: 1-D array of training labels.
        data_test: 2-D array of test samples, one per row.
        labels_test: 1-D array of true test labels.
        k: Number of neighbours used for voting.
        dists: Optional precomputed ``distance_matrix(data_train, data_test)``
            result of shape (n_train, n_test).  Passing it avoids recomputing
            the distances when evaluating several values of ``k``.

    Returns:
        Percentage (0-100) of test points whose predicted label equals the
        true label.

    Raises:
        ValueError: If the predictions and ``labels_test`` differ in shape.
        ZeroDivisionError: If the test set is empty.
    """
    if dists is None:
        dists = distance_matrix(data_train, data_test)

    predictions = knn_predict(dists, labels_train, k)
    labels_test = np.asarray(labels_test)
    if predictions.shape != labels_test.shape:
        raise ValueError(
            "labels_test has shape {} but {} predictions were made".format(
                labels_test.shape, predictions.shape
            )
        )

    correct = np.count_nonzero(predictions == labels_test)
    return correct / predictions.size * 100


def main():
    """Plot k-NN test accuracy for k = 1..20 on ``data_batch_1``."""
    # NOTE(review): read_cifar_batch and split_dataset are neither defined
    # nor imported in this file; presumably they live in a sibling
    # data-loading module -- TODO confirm and import them here.
    datas, labels = read_cifar_batch('data_batch_1')
    print(datas.shape, labels.shape)
    data_train, data_test, labels_train, labels_test = split_dataset(datas, labels)
    print(data_train.shape, data_test.shape, labels_train.shape)

    # Computed once and shared by every k below.
    dists = distance_matrix(data_train, data_test)
    print(dists.shape)
    print()

    x = np.arange(1, 21)
    result = [
        evaluate_knn(data_train, labels_train, data_test, labels_test, k, dists=dists)
        for k in range(1, 21)
    ]

    # plotting
    plt.title("Plot graph")
    plt.xlabel("K neighbors")
    plt.ylabel("Accuracy %")
    plt.plot(x, result, color="red")
    plt.show()


if __name__ == "__main__":
    main()