From 2418acfde0439fae05e2996b2b499b15069f28f2 Mon Sep 17 00:00:00 2001
From: BaptisteBrd <75663738+BaptisteBrd@users.noreply.github.com>
Date: Fri, 10 Nov 2023 15:01:27 +0100
Subject: [PATCH] knn save

---
 knn.py | 34 +++++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/knn.py b/knn.py
index cffe512..fa4aba5 100644
--- a/knn.py
+++ b/knn.py
@@ -1,14 +1,38 @@
 import numpy as np
 
 def distance_matrix(a,b):
-    sx = np.sum(a**2, axis=1, keepdims=True)
-    sy = np.sum(b**2, axis=1, keepdims=True)
-    dists = np.sqrt(-2 * a.dot(b.T) + sx + sy.T)
-    return dists
+    sum_a = np.sum(a**2, axis=1, keepdims=True)
+    sum_b = np.sum(b**2, axis=1, keepdims=True)
+    dist = np.sqrt(-2 * a.dot(b.T) + sum_a + sum_b)
+    return dist
 
 
 
+#def knn_predict(dists, labels_train, k):
+    #
+    # 
+def knn_predict(dists, labels_train, k):
+    predicted_labels = []
+    # For every image in the test set
+    for i in range(len(dists)):
+        # Initialize an array to store the neighbors
+        classes = [0] * 10
+        # indexes of the closest neighbors
+        indexes_closest_nb = np.argsort(dists[i])[:k]
+        for index in indexes_closest_nb:
+            #find the labels of the training batch associated with the closest indexes
+            classes[labels_train[index]] += 1
+        #The class with the highest neighbors is added to the predicted labels
+        predicted_labels.append(np.argmax(classes))
+    return(np.array(predicted_labels))
+
+def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
+    
+
+
 
 
 if __name__ == "__main__" :
-    
+    a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
+    b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
+    print(distance_matrix(a1,b1))
\ No newline at end of file
-- 
GitLab