diff --git a/knn.py b/knn.py
index 043e7ca81258066d76c9d9e0c509994504739109..54a3d58458ff0f3d7b8a9d94d002d59510276ff3 100644
--- a/knn.py
+++ b/knn.py
@@ -2,12 +2,36 @@ import numpy as np
 
 
 def distance_matrix(mat1, mat2):
-    dists = np.sqrt(np.matmul(mat1, mat1)+ np.matmul(mat2, mat2) - 2 * np.matmul(mat1, mat2))
+    square1 = np.sum(np.square(mat1), axis = 1)
+    square2 = np.sum(np.square(mat2), axis = 1)
+    prod = np.dot(mat1, mat2.T)
+    dists = np.sqrt(square1 + square2 - 2 * prod)
     return dists
 
 def knn_predict(dists, labels_train, k):
-    return predicted_labels
+    # results matrix initialisation
+    predicted_labels = np.zeros(len(dists))
+    # loop on all the test images
+    for i in range(len(dists)):
+        # sort and keep the k shortest dists for test image i
+        sorted_dists = np.argsort(dists[i])
+        k_sorted_dists = sorted_dists[:k]
+        # get the matching labels_train
+        closest_labels = labels_train[k_sorted_dists]
+        # get the most common labels_train
+        predicted_labels[i] = np.argmax(closest_labels)
+    return np.array(predicted_labels)
 
+def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
+    dists = distance_matrix(data_test, data_train)
+    tot = len(data_test)
+    accurate = 0
+    predicted_labels = knn_predict(dists, labels_train, k)
+    for i in range(tot):
+        if predicted_labels[i] == labels_test[i]:
+            accurate += 1
+    accuracy = accurate/tot
+    return accuracy
 
 
 
@@ -15,22 +39,17 @@ def knn_predict(dists, labels_train, k):
 
 
 
-mat1 = np.array([[1, 2],
-                 [3, 4]])
 
-mat2 = np.array([[5, 6],
-                 [7, 8]])
+if __name__ == "__main__":
 
-A = np.matmul(mat1, mat1)
-print(A)
+    bench_knn()
+    # data, labels = read_cifar.read_cifar('image-classification/data/cifar-10-batches-py')
+    # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9)
+    # print(evaluate_knn(X_train, y_train, X_test, y_test, 5))
+    # print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
 
-B = np.matmul(mat2, mat2)
-print(B)
-
-C = 2 * np.matmul(mat1, mat2)
-print(C)
-
-print(A + B - C)
-
-mat = distance_matrix(mat1, mat2)
-print(mat)
\ No newline at end of file
+    # y_test = []
+    # x_test = np.array([[1,2],[4,6]])
+    # x_train = np.array([[2,4],[7,2],[4,6]])
+    # y_train = [1,2,1]
+    # dist = distance_matrix(x_test,x_train)