diff --git a/knn.py b/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d6e896ebec5c2aad32ea8066a327849e08bc22
--- /dev/null
+++ b/knn.py
@@ -0,0 +1,125 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
def distance_matrix(M1, M2):
    """Return the pairwise Euclidean distance matrix between two sets of row vectors.

    Parameters
    ----------
    M1 : ndarray of shape (n1, d)
    M2 : ndarray of shape (n2, d)

    Returns
    -------
    ndarray of shape (n1, n2) where entry (i, j) is ||M1[i] - M2[j]||_2.
    """
    # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b to avoid an explicit double loop.
    M1_2 = np.sum(M1**2, axis=1, keepdims=True)
    M2_2 = np.sum(M2**2, axis=1, keepdims=True)
    M1M2 = np.dot(M1, M2.T)
    # Floating-point cancellation can make the expression slightly negative
    # (e.g. distance of a point to itself); clamp at 0 so sqrt never yields NaN.
    sq = np.maximum(M1_2 + M2_2.T - 2 * M1M2, 0.0)
    return np.sqrt(sq)
+
def knn_predict(dists, labels_train, k):
    """Predict a label for each test sample by majority vote among its k nearest neighbors.

    Parameters
    ----------
    dists : ndarray of shape (n_test, n_train)
        Distance matrix between test and training samples.
    labels_train : ndarray of shape (n_train,)
        Integer class labels of the training samples.
    k : int
        Number of neighbors to vote (1 <= k <= n_train).

    Returns
    -------
    ndarray of shape (n_test,) of predicted labels (float dtype).
    """
    n_test = dists.shape[0]
    predict_labels = np.zeros(n_test)
    # Partition every row at once: columns [:k] then hold the k smallest
    # distances. kth=k-1 avoids the partial sort that range(k) forces —
    # the order among the k neighbors is irrelevant for a majority vote.
    nearest = np.argpartition(dists, kth=k - 1, axis=1)[:, :k]
    for i in range(n_test):
        k_labels = labels_train[nearest[i]]
        unique_labels, counts = np.unique(k_labels, return_counts=True)
        # np.unique returns labels sorted, so voting ties break toward the
        # smallest label (same behavior as the original argmax-on-counts).
        predict_labels[i] = unique_labels[np.argmax(counts)]
    return predict_labels
+
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Return the classification accuracy of a k-NN classifier on the test set.

    Parameters
    ----------
    data_train : ndarray of shape (n_train, d)
    labels_train : ndarray of shape (n_train,)
    data_test : ndarray of shape (n_test, d)
    labels_test : ndarray of shape (n_test,) or (n_test, 1)
        Accepts both a flat vector and the column array produced by split_dataset.
    k : int
        Number of neighbors.

    Returns
    -------
    float in [0, 1]: fraction of correctly classified test samples.
    """
    dists = distance_matrix(data_test, data_train)
    predict_labels = knn_predict(dists, labels_train, k)
    true_labels = np.ravel(labels_test)
    # Predictions are floats holding integer class ids, so compare with a
    # tolerance; np.mean avoids the accumulated 1/n floating-point error of
    # summing per-sample increments.
    return float(np.mean(np.abs(predict_labels - true_labels) < 1e-6))
+
+
+
+
+
+##
+# COPY OF READCIFAR.PY AS I WAS UNABLE TO IMPORT IT
+
+import numpy as np
+import os
+import pickle
+import random
+
def unpickle(file):
    """Deserialize a pickled file and return its contents.

    CIFAR batches are pickled with byte-string keys, hence encoding='bytes'.
    NOTE: pickle.load must only be used on trusted files — it can execute
    arbitrary code when given a malicious payload.

    Parameters
    ----------
    file : str or path-like
        Path to the pickled file.

    Returns
    -------
    The deserialized object (a dict with bytes keys for CIFAR batches).
    """
    # Uses the module-level `import pickle`; the previous function-local
    # import was redundant, and the result no longer shadows builtin `dict`.
    with open(file, 'rb') as f:
        batch_dict = pickle.load(f, encoding='bytes')
    return batch_dict
+
def read_cifar_batch(batch_path):
    """Load one CIFAR-10 batch file and return normalized data with labels.

    Parameters
    ----------
    batch_path : str or path-like
        Path to a single pickled CIFAR batch.

    Returns
    -------
    data : float32 ndarray of shape (n, 3072), pixel values scaled to [0, 1]
    labels : int64 ndarray of shape (n,)
    """
    # CIFAR batches are pickled with byte-string keys, hence encoding='bytes'.
    with open(batch_path, 'rb') as fh:
        raw = pickle.load(fh, encoding='bytes')

    # Normalize raw uint8 pixels to [0, 1] floats.
    images = np.asarray(raw[b'data'], dtype=np.float32) / 255.0
    targets = np.asarray(raw[b'labels'], dtype=np.int64)
    return images, targets
+
def read_cifar(batch_dir):
    """Load and concatenate every CIFAR batch (train and test) found in a directory.

    Parameters
    ----------
    batch_dir : str or path-like
        Directory containing files named data_batch_* and test_batch.

    Returns
    -------
    data : float32 ndarray of shape (N, 3072), all batches stacked
    labels : int64 ndarray of shape (N,)

    Raises
    ------
    ValueError (from np.concatenate) if no batch file is found.
    """
    data_batches = []
    label_batches = []

    # sorted() makes the concatenation order deterministic — os.listdir
    # returns names in arbitrary, filesystem-dependent order.
    for file_name in sorted(os.listdir(batch_dir)):
        if file_name.startswith("data_batch") or file_name.startswith("test_batch"):
            batch_path = os.path.join(batch_dir, file_name)
            data, labels = read_cifar_batch(batch_path)
            data_batches.append(data)
            label_batches.append(labels)

    # Stack all batches into one dataset.
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    return data, labels
+    
def split_dataset(data, labels, split):
    """Randomly shuffle the dataset and split it into train and test parts.

    Parameters
    ----------
    data : ndarray of shape (n, d)
    labels : ndarray of shape (n,)
    split : float in [0, 1]
        Fraction of samples assigned to the training set.

    Returns
    -------
    data_train : ndarray of shape (round(n*split), d)
    labels_train : ndarray of shape (round(n*split), 1)
    data_test : ndarray of shape (n - round(n*split), d)
    labels_test : ndarray of shape (n - round(n*split), 1)

    Raises
    ------
    ValueError
        If data and labels do not have the same number of rows.
    """
    # Raise (the original returned the exception object, which callers
    # would silently unpack as data).
    if data.shape[0] != labels.shape[0]:
        raise ValueError("data et labels doivent avoir le même nombre de lignes !")

    n = data.shape[0]
    train_size = round(n * split)

    # Actually shuffle: the original built range(n) and never permuted it,
    # so the "random" split was just a head/tail slice.
    shuffle_index = np.random.permutation(n)

    data_shuffled = data[shuffle_index]
    # Keep labels as a column vector, matching the original [[label], ...] shape.
    labels_shuffled = labels[shuffle_index].reshape(-1, 1)

    data_train = data_shuffled[:train_size]
    labels_train = labels_shuffled[:train_size]
    data_test = data_shuffled[train_size:]
    labels_test = labels_shuffled[train_size:]

    return data_train, labels_train, data_test, labels_test
+
+##
+
+
+
+
+    
if __name__ == "__main__":
    data_folder = 'C:\\Users\\hugol\\Desktop\\Centrale Lyon\\Centrale Lyon 4A\\Informatique\\Machine Learning\\BE1\\cifar-10-batches-py'
    batch_filename = 'data_batch_1'

    # Load every batch found in the dataset folder.
    data_all, labels_all = read_cifar(data_folder)
    print("Data shape:", data_all.shape)
    print("Labels shape:", labels_all.shape)

    # 90/10 train/test split.
    data_train, labels_train, data_test, labels_test = split_dataset(data_all, labels_all, 0.9)
    print("Data train shape:", data_train.shape)
    print("Labels train shape:", labels_train.shape)
    print("Data test shape:", data_test.shape)
    print("Labels test shape:", labels_test.shape)

    # Accuracy of the k-NN classifier for k = 1..20.
    k_values = range(1, 21)
    acc = np.array([
        evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        for k in k_values
    ])

    # Plot accuracy as a function of k.
    plt.figure()
    plt.plot(k_values, acc)
    plt.show()