Skip to content
Snippets Groups Projects
Commit d4e4acba authored by Muniz Silva Samuel's avatar Muniz Silva Samuel
Browse files

Upload New File

parent 880bf5cd
No related branches found
No related tags found
No related merge requests found
knn.py 0 → 100644
import numpy as np
import tensorflow as tf
import pandas as pd
import pickle
import os
import scipy
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
import matplotlib.pyplot as plt
def distance_matrix(data_test,data_train):
dists = np.array([np.sum((data_train-l)**2,axis=1)**.5 for l in data_test])
return dists
#receives a 2d array data_train(M,k) and a data_test (N,k),
#returning a 2d array(N,M) such that dists[i,j] represents
#the distance between the i-th data_test row and the j-th data_train row
#in resume, each column represent a distance of a training point to all other
def knn_predict(dists , labels_train , k):
#classif = np.array(0)
print(labels_train[:20])
print(labels_train.size)
classif = []
for testRows in dists.T:
distances = np.stack((testRows,labels_train),axis = 1)
distances = distances[distances[:, 0].argsort()]
#for picturesClasses in distances[:k,1]:
countArray = [np.count_nonzero(distances[:k,1]==i) for i in range(0,10)]
classif = np.append(classif,np.argmax(countArray))
classif = np.array(classif , dtype = int)
return classif
def evaluate_knn(data_train,labels_train,data_test,labels_test,k):
classif = np.array(knn_predict(distance_matrix(data_train,data_test) , labels_train , k))
result = np.array(classif == labels_test)
acc = np.count_nonzero(result) / np.size(result)
return acc*100
datas,labels = read_cifar_batch('data_batch_1')
print(datas.shape,labels.shape)
dataTrain,dataTest,labelsTrain,labelsTest = split_dataset(datas,labels)
print(dataTrain.shape,dataTest.shape,labelsTrain.shape)
distanceMatrix = distance_matrix(dataTrain,dataTest)
print(distanceMatrix.shape)
print()
result = []
for i in range (1,21):
result = np.append(result,evaluate_knn(dataTrain,labelsTrain,dataTest,labelsTest,i))
x = np.arange(1, 21)
# plotting
plt.title("Plot graph")
plt.xlabel("K neighbors")
plt.ylabel("Accuracy %")
plt.plot(x, result, color ="red")
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment