diff --git a/knn.py b/knn.py
index 9e4402825cc444d11e6a465ad20fe5db65373eb2..365c08101ccac64ee6514d66cd90e04242a762d8 100644
--- a/knn.py
+++ b/knn.py
@@ -3,26 +3,21 @@ import read_cifar
 import matplotlib.pyplot as plt
 
 def distance_matrix(A,B):
-
+     # Calculating the squared sum of elements in each row for matrices A and B
     sum_of_squares_A = np.sum(A**2, axis=1,keepdims=True)
     sum_of_squares_B = np.sum(B**2, axis=1,keepdims=True).T
     dot_product = np.dot(A, B.T)
 
-
+    # Computing the Euclidean distance matrix
     dists=np.sqrt(sum_of_squares_A+sum_of_squares_B-2*dot_product)
 
     return dists
 
-
-
-#def knn_predict(dists, labels_train, k):
-    #
-    # 
 def knn_predict(dists, labels_train, k):
     predicted_labels = []
-    # For every image in the test set
+    # Iterating through each test data point's distances to train data
     for i in range(len(dists)):
-        # Initialize an array to store the neighbors
+        # Counting the frequency of each class among the k nearest neighbors
         classes = [0] * 10
         # indexes of the closest neighbors
         indexes_closest_nb = np.argsort(dists[i])[:k]
@@ -37,8 +32,7 @@ def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
     rate = 0
     dist_train_test = distance_matrix(data_test, data_train)
     prediction = knn_predict(dist_train_test, labels_train, k)
-    print(len(prediction))
-    print(len(labels_test))
+    # Comparing predictions to actual test labels to calculate accuracy
     for j in range(len(prediction)):
         if prediction[j]==labels_test[j]:
             rate +=1
@@ -51,12 +45,16 @@ def knn_final():
 
     data,labels = read_cifar.read_cifar("data/cifar-10-batches-py")
     data_train_f, labels_train_f, data_test_f, labels_test_f = read_cifar.split_dataset(data, labels, 0.9)
+    
+    # Testing KNN for different values of k and storing the accuracy
 
     for k in range_k :
         print(k)
         rate_k = evaluate_knn(data_train_f, labels_train_f, data_test_f, labels_test_f, k)
         rates.append(rate_k)
 
+    # Plotting the accuracy as a function of k
+
     plt.figure(figsize=(10, 7))
     plt.xlabel('k')
     plt.ylabel('Accuracy rate')
@@ -72,7 +70,4 @@ def knn_final():
 
 if __name__ == "__main__" :
 
-    knn_final()
-    #a1 = np.array([[0,0,1],[0,0,0],[1,1,2]])
-    #b1 = np.array([[1,3,1], [1,1,4], [1,5,1]])
-    #print(distance_matrix(a1,b1))
\ No newline at end of file
+    knn_final()
\ No newline at end of file