diff --git a/knn.py b/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d6e896ebec5c2aad32ea8066a327849e08bc22
--- /dev/null
+++ b/knn.py
@@ -0,0 +1,125 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import pickle
+
+
+def distance_matrix(M1, M2):
+    # Pairwise Euclidean distances, using ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
+    M1_2 = np.sum(M1**2, axis=1, keepdims=True)
+    M2_2 = np.sum(M2**2, axis=1, keepdims=True)
+    M1M2 = np.dot(M1, M2.T)
+    # Clip tiny negative values caused by floating-point error before taking the square root
+    dists = np.sqrt(np.maximum(M1_2 + M2_2.T - 2 * M1M2, 0.0))
+    return dists
+
+
+def knn_predict(dists, labels_train, k):
+    predict_labels = np.zeros(dists.shape[0])
+    for i in range(dists.shape[0]):
+        # Indices of the k nearest training samples on row i of dists
+        k_indexes = np.argpartition(dists[i, :], k - 1)[:k]
+        # Labels of those neighbours
+        k_labels = labels_train[k_indexes]
+        # Count label occurrences among the k neighbours
+        unique_labels, counts = np.unique(k_labels, return_counts=True)
+        # Keep the most frequent label
+        predict_labels[i] = unique_labels[np.argmax(counts)]
+    return predict_labels
+
+
+def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
+    dists = distance_matrix(data_test, data_train)
+    predict_labels = knn_predict(dists, labels_train, k)
+    # Accuracy: fraction of test samples whose prediction matches the true label
+    acc = np.mean(predict_labels == labels_test[:, 0])
+    return acc
+
+
+##
+# COPY OF READCIFAR.PY AS I WAS UNABLE TO IMPORT IT
+
+def unpickle(file):
+    with open(file, 'rb') as f:
+        batch_dict = pickle.load(f, encoding='bytes')
+    return batch_dict
+
+
+def read_cifar_batch(batch_path):
+    with open(batch_path, 'rb') as file:
+        # Unpickle the batch
+        batch = pickle.load(file, encoding='bytes')
+
+    # Extract data and labels
+    data = np.array(batch[b'data'], dtype=np.float32) / 255.0
+    labels = np.array(batch[b'labels'], dtype=np.int64)
+
+    return data, labels
+
+
+def read_cifar(batch_dir):
+    data_batches = []
+    label_batches = []
+
+    # Iterate over the batch files
+    for file_name in os.listdir(batch_dir):
+        if file_name.startswith("data_batch") or file_name.startswith("test_batch"):
+            batch_path = os.path.join(batch_dir, file_name)
+            data, labels = read_cifar_batch(batch_path)
+            data_batches.append(data)
+            label_batches.append(labels)
+
+    # Combine data and labels from all batches
+    data = np.concatenate(data_batches, axis=0)
+    labels = np.concatenate(label_batches, axis=0)
+
+    return data, labels
+
+
+def split_dataset(data, labels, split):
+    # Check that data and labels have the same number of rows
+    if data.shape[0] != labels.shape[0]:
+        raise ValueError("data and labels must have the same number of rows!")
+
+    # Size of the training split
+    train_size = round(data.shape[0] * split)
+
+    # Shuffle data and labels with a common random permutation
+    shuffle_index = np.random.permutation(data.shape[0])
+    shuffled_data = data[shuffle_index]
+    shuffled_labels = labels[shuffle_index].reshape(-1, 1)
+
+    # Split into train and test sets
+    data_train = shuffled_data[:train_size]
+    labels_train = shuffled_labels[:train_size]
+    data_test = shuffled_data[train_size:]
+    labels_test = shuffled_labels[train_size:]
+
+    return data_train, labels_train, data_test, labels_test
+
+##
+
+
+if __name__ == "__main__":
+    data_folder = 'C:\\Users\\hugol\\Desktop\\Centrale Lyon\\Centrale Lyon 4A\\Informatique\\Machine Learning\\BE1\\cifar-10-batches-py'
+    batch_filename = 'data_batch_1'
+
+    data_all, labels_all = read_cifar(data_folder)
+    print("Data shape:", data_all.shape)
+    print("Labels shape:", labels_all.shape)
+
+    data_train, labels_train, data_test, labels_test = split_dataset(data_all, labels_all, 0.9)
+    print("Data train shape:", data_train.shape)
+    print("Labels train shape:", labels_train.shape)
+    print("Data test shape:", data_test.shape)
+    print("Labels test shape:", labels_test.shape)
+
+    # Accuracy of the k-NN classifier for k = 1..20
+    acc = np.zeros(20)
+    for k in range(1, 21):
+        acc[k - 1] = evaluate_knn(data_train, labels_train, data_test, labels_test, k)
+
+    plt.figure()
+    plt.plot(range(1, 21), acc)
+    plt.xlabel("k")
+    plt.ylabel("Accuracy")
+    plt.show()
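
Note: a quick way to sanity-check the vectorised distance computation is to compare distance_matrix against a brute-force double loop on small random arrays. This is a minimal sketch, not part of knn.py; the shapes and seed are arbitrary.

    # sanity_check.py (sketch): verify distance_matrix against a brute-force loop
    import numpy as np
    from knn import distance_matrix

    rng = np.random.default_rng(0)
    A = rng.random((5, 3))   # arbitrary small test matrix
    B = rng.random((7, 3))   # arbitrary small train matrix

    # Brute-force pairwise Euclidean distances
    expected = np.array([[np.linalg.norm(a - b) for b in B] for a in A])

    assert np.allclose(distance_matrix(A, B), expected)

np.allclose absorbs the small floating-point differences between the two formulations.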