Skip to content
Snippets Groups Projects
Commit cfaf9737 authored by Bourry Malo's avatar Bourry Malo
Browse files

Partie knn finie

parent 9ea34c98
Branches
No related tags found
No related merge requests found
File added
import numpy as np import numpy as np
from read_cifar import read_cifar, split_dataset
import matplotlib.pyplot as plt
def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray): def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray):
sum_squares_1 = np.sum(matrix_a**2, axis = 1, keepdims = True) sum_squares_1 = np.sum(matrix_a**2, axis = 1, keepdims = True)
...@@ -10,11 +12,43 @@ def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray): ...@@ -10,11 +12,43 @@ def distance_matrix(matrix_a: np.ndarray, matrix_b: np.ndarray):
return dists return dists
def knn_predict(dists: np.ndarray, labels_train: np.ndarray, k: int):
    """Predict a label for each test sample by majority vote among its
    k nearest training neighbours.

    Parameters
    ----------
    dists : np.ndarray
        (n_test, n_train) matrix; dists[i, j] is the distance between
        test sample i and training sample j.
    labels_train : np.ndarray
        Labels of the n_train training samples.
    k : int
        Number of neighbours used for the vote.

    Returns
    -------
    np.ndarray
        1-D array of n_test predicted labels.
    """
    n_test = np.size(dists, 0)
    labels_predicts = np.zeros(n_test)
    for i in range(n_test):
        # Indices of the k smallest distances, i.e. the k nearest neighbours.
        # (The original called np.argmin with an invalid second argument and
        # referenced an undefined name `dist`; argsort gives the sorting
        # indices we actually need.)
        k_neighbors_index = np.argsort(dists[i, :])[:k]
        labels_k_neighbors = labels_train[k_neighbors_index]
        # Count how many times each class appears among the k neighbours.
        values, counts = np.unique(labels_k_neighbors, return_counts=True)
        # The most frequent class wins. np.argmax(counts) indexes into
        # `values` (the unique classes), not into labels_k_neighbors —
        # the original indexed the wrong array.
        labels_predicts[i] = values[np.argmax(counts)]
    return labels_predicts
def evaluate_knn(data_train: np.ndarray, labels_train: np.ndarray, data_test: np.ndarray, labels_test: np.ndarray, k: int):
    """Compute the accuracy of a k-NN classifier on a test set.

    Parameters
    ----------
    data_train : np.ndarray
        (n_train, d) training samples.
    labels_train : np.ndarray
        Labels of the training samples.
    data_test : np.ndarray
        (n_test, d) test samples.
    labels_test : np.ndarray
        Ground-truth labels of the test samples.
    k : int
        Number of neighbours used by the classifier.

    Returns
    -------
    float
        Fraction of test samples whose predicted label matches the
        ground truth.
    """
    dists = distance_matrix(data_test, data_train)
    labels_predicts = knn_predict(dists, labels_train, k)
    # Labels are integers (possibly stored as floats), so exact equality is
    # safe; this replaces the original per-element loop with its
    # abs(diff) < 1e-7 tolerance test, which is equivalent here.
    accuracy = float(np.mean(labels_predicts == labels_test))
    return accuracy
def plot_knn(data_train: np.ndarray, labels_train: np.ndarray, data_test: np.ndarray, labels_test: np.ndarray, n: int):
    """Plot k-NN accuracy as a function of k for k = 1..n.

    Parameters
    ----------
    data_train, labels_train : np.ndarray
        Training samples and their labels.
    data_test, labels_test : np.ndarray
        Test samples and their ground-truth labels.
    n : int
        Largest number of neighbours to evaluate.

    Returns
    -------
    None — displays a matplotlib figure as a side effect.
    """
    accuracy_vector = np.zeros(n)
    for k in range(1, n + 1):
        # Original bugs: `k` was never passed to evaluate_knn (TypeError),
        # and accuracy_vector[k] skipped slot 0 and overflowed at k == n;
        # k-1 maps k = 1..n onto valid indices 0..n-1.
        accuracy_vector[k - 1] = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
    # Plot accuracy against the actual k values rather than array indices.
    plt.plot(range(1, n + 1), accuracy_vector)
    plt.show()
    return
if __name__ == "__main__":
    # Load the full CIFAR dataset, split it 80/20 into train/test, and
    # evaluate a k-NN classifier with k = 5 neighbours.
    data, labels = read_cifar()
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)
    k = 5  # number of neighbours
    accuracy = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
...@@ -31,9 +31,9 @@ def read_cifar(): ...@@ -31,9 +31,9 @@ def read_cifar():
dict = pickle.load(fo, encoding='bytes') dict = pickle.load(fo, encoding='bytes')
data.append(dict[b'data']) data.append(dict[b'data'])
labels.append(dict[b'labels']) labels.append(dict[b'labels'])
data = np.array(data, np.float32) data = np.array(data, np.float16)
labels = np.array(labels, np.int64) labels = np.array(labels, np.int64)
return np.reshape(data, (np.size(data, 0)*np.size(data, 1), np.size(data, 2))), np.reshape(labels, (np.size(labels, 0)*np.size(labels, 1), 1)) return np.reshape(data, (np.size(data, 0)*np.size(data, 1), np.size(data, 2))), np.reshape(labels, (np.size(labels, 0)*np.size(labels, 1)))
def split_dataset(data: np.ndarray, labels: np.ndarray, split: float): def split_dataset(data: np.ndarray, labels: np.ndarray, split: float):
...@@ -50,5 +50,5 @@ def split_dataset(data: np.ndarray, labels: np.ndarray, split: float): ...@@ -50,5 +50,5 @@ def split_dataset(data: np.ndarray, labels: np.ndarray, split: float):
if __name__ == "__main__":
    # Smoke test for the data-loading helpers: load CIFAR, split 80/20 into
    # named train/test arrays, and print a marker on success.
    data, labels = read_cifar()
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.8)
    print(1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment