Skip to content
Snippets Groups Projects
Commit 8cb83950 authored by BaptisteBrd's avatar BaptisteBrd
Browse files

modif knn

parent 12ff5e60
Branches main
No related tags found
No related merge requests found
...@@ -3,26 +3,21 @@ import read_cifar ...@@ -3,26 +3,21 @@ import read_cifar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def distance_matrix(A,B):
    """Return the pairwise Euclidean distances between the rows of A and B.

    The matrix is built from the expansion
    ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so that the only heavy
    operation is a single matrix product.

    Args:
        A: array-like of shape (n_a, d); one sample per row.
        B: array-like of shape (n_b, d); one sample per row.

    Returns:
        numpy.ndarray of shape (n_a, n_b) whose entry ``[i, j]`` is the
        Euclidean distance between ``A[i]`` and ``B[j]``.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Integer/bool inputs (e.g. raw uint8 pixel values) would silently wrap
    # around when squared, so promote them to float before any arithmetic.
    # Floating-point inputs keep their dtype to avoid inflating memory use.
    if A.dtype.kind in "biu":
        A = A.astype(np.float64)
    if B.dtype.kind in "biu":
        B = B.astype(np.float64)

    # Squared norm of every row: a column vector for A, a row vector for B,
    # so that their sum broadcasts to an (n_a, n_b) matrix.
    sum_of_squares_A = np.sum(A**2, axis=1, keepdims=True)
    sum_of_squares_B = np.sum(B**2, axis=1, keepdims=True).T
    dot_product = np.dot(A, B.T)

    squared_dists = sum_of_squares_A + sum_of_squares_B - 2 * dot_product
    # Rounding errors can push the squared distance of (nearly) identical
    # rows slightly below zero, which would make np.sqrt return NaN.
    dists = np.sqrt(np.maximum(squared_dists, 0))
    return dists
#def knn_predict(dists, labels_train, k):
#
#
def knn_predict(dists, labels_train, k): def knn_predict(dists, labels_train, k):
predicted_labels = [] predicted_labels = []
# For every image in the test set # Iterating through each test data point's distances to train data
for i in range(len(dists)): for i in range(len(dists)):
# Initialize an array to store the neighbors # Counting the frequency of each class among the k nearest neighbors
classes = [0] * 10 classes = [0] * 10
# indexes of the closest neighbors # indexes of the closest neighbors
indexes_closest_nb = np.argsort(dists[i])[:k] indexes_closest_nb = np.argsort(dists[i])[:k]
...@@ -37,8 +32,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k): ...@@ -37,8 +32,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
rate = 0 rate = 0
dist_train_test = distance_matrix(data_test, data_train) dist_train_test = distance_matrix(data_test, data_train)
prediction = knn_predict(dist_train_test, labels_train, k) prediction = knn_predict(dist_train_test, labels_train, k)
print(len(prediction)) # Comparing predictions to actual test labels to calculate accuracy
print(len(labels_test))
for j in range(len(prediction)): for j in range(len(prediction)):
if prediction[j]==labels_test[j]: if prediction[j]==labels_test[j]:
rate +=1 rate +=1
...@@ -52,11 +46,15 @@ def knn_final(): ...@@ -52,11 +46,15 @@ def knn_final():
data,labels = read_cifar.read_cifar("data/cifar-10-batches-py") data,labels = read_cifar.read_cifar("data/cifar-10-batches-py")
data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9) data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
# Testing KNN for different values of k and storing the accuracy
for k in range_k : for k in range_k :
print(k) print(k)
rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k) rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
rates.append(rate_k) rates.append(rate_k)
# Plotting the accuracy as a function of k
plt.figure(figsize=(10, 7)) plt.figure(figsize=(10, 7))
plt.xlabel('k') plt.xlabel('k')
plt.ylabel('Accuracy rate') plt.ylabel('Accuracy rate')
...@@ -73,6 +71,3 @@ def knn_final(): ...@@ -73,6 +71,3 @@ def knn_final():
if __name__ == "__main__" : if __name__ == "__main__" :
knn_final() knn_final()
\ No newline at end of file
#a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
#b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
#print(distance_matrix(a1,b1))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment