From a28545d911dce6afa556e235ee2a9796afc63305 Mon Sep 17 00:00:00 2001 From: Milan <milan.cart@ecl20.ec-lyon.fr> Date: Thu, 9 Nov 2023 22:15:19 +0100 Subject: [PATCH] Part 2 : KNN --- __pycache__/read_cifar.cpython-311.pyc | Bin 2090 -> 2710 bytes knn.py | 80 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/__pycache__/read_cifar.cpython-311.pyc b/__pycache__/read_cifar.cpython-311.pyc index 3c4534351e5e03b53d76e9087a1ffc193f08cd06..3d053de8f665020f45e54022379f58efabed47fd 100644 GIT binary patch delta 1055 zcmZ1_Fio_6IWI340|NuYOncwdJ**52k3k$5W`i<5pJ8BNn9h*GkiwY55XA_lnWC6d z7*d#Wm~&a8SQsI4tWj)mIp!9IDE1VV7KSK}6xJ4oD9%)_6t-+8kcmZhObn^)DeTJ_ z8R{8UGlGObIEIOVA&NVShlwGTBZXrbBLj9h&Q$JHo)pey%nS^xnL+wMIEICRA&NJO zFN!}(fQcbhFqJozFO@%4U>OSo!)g|g3J_*wNMQ^HV@<A?AOSy-m5jFpN{SLQ^WsZV zi%a5*3vx0`Htx`5Vijj#VBp;RjA=I$ABcI2EwP{=H7})zck+5xNj4D%28JTh$>&(L z#KagF7;bSDr6#7tCugQ578NNoFffQr=3{$TpA2^&0|NsnFhTs!F5ti{f%}_*p@yM` zF@>>)sfvq%A%$rTGpegWeye4!WvOMYWvgMXVX0xQVM}8PX3%8uE0P4c0~Dy5Ot&~8 zKFF*}y~R_Mn3s}YpBrCXl30>j1hQ6>y$IrHkaJnV{=CH=5fAqJEw%`-_iwRR!UQV8 z0>$hM3=9eoPz3fUFIZ=MN@7W3acW7CE&~IDCIbUQu>}JILj%KIb@Lr8JJK!$23-h^ zx@aC4bs;PJf_lzH^_(l}ITu+}8r(a=CWOyOd%!I;!EJ{7g0Kr*Iu}@Uic~<3;)svW z1w};sWL|bPX2UA~$z4q5llj>t>eFY<1c8_KnnJf&^Gb6IDvNkQ9_I#W6EDurNlh%u z)630INzI8bPR&V8F3HT#D*{FPEw+Nhyp+V^TTBHhMWA>D#}3%pw>aYC^AdAY<Kqzl zQv~u$krv2oZ4jXYvVb+MpeQr1WF-UGnc!gf#bJ}1pHiBWYFA{&z`y{C%Hn4X3=AKb z85tRGFbG|Mp&Jam4PbbK!Q=uey1^iP0fs)XF);E?Ud^sruf@!$`+)&Y@G!9ObYxv( zmbkzy@jyuFf{@Y$Wv`1u-dBXY8*Lg)I>c@W$xP6?Ae(nlDF2F3K2%K2ctP$4mmO*c zlrO5eT~Twp$gJF8-r+XEeMZ;~W{w8a4)F<UGt?KDU1Zj}!mM=xhCXmJaIr{UV3q)T F6aY9$3%LLQ delta 528 zcmbOxx=NsaIWI340|Ns?kd<%heP#xR#~=<2vqBl4+ZY%arZc24q%h_%L@|PCrYNQq zrWS@M<`m`@hA5U)))ba(CXo6f9VUiUwiMQ7j0_B`89^c-%*X%|WnxHWPGLimv1DLi zK;kkmFhCThu%nBnFa|S#uqMY#ka>Qag%~xNSVb8a7+5zuFz;s41u<{2B^DH<=A{&I zGcYjR;!e)bOHM3F%}Xpv)nvQHQjl1ZQ6$8`z)&PSS%XbWOa!Eet0*-wB|bSbEwQLb zk%56hU~&fAJHcBV@$tEdnR)T?w^(yCa}x6=cXO!K8{OhB&MM5v&Pqwk&Pqzj%}UZ_ zzr~tYnp;o_wy}s4WH`tnV52~W<R#{&f($7r%FHWS$xx&OQUS8~7l%!5eoARhs$G#5 z*xL*Y#eED63?G;o85wUd2wi}o2MoLoVEAG3drrA14F(pTj;u?}5*L^yZZLB+n0AOa z+O^m<*gXJoE^tbMkd)&^PNyrJP8XS-uP{4bV0MN|N?l}@zQQbh0fs(sGq6ZrV3q(o F3jh+Je4hXS diff --git a/knn.py b/knn.py index e69de29..0a687fb 100644 --- a/knn.py +++ b/knn.py @@ -0,0 +1,80 @@ +import read_cifar +import numpy as np +import matplotlib.pyplot as plt + +def distance_matrix(matrix1, matrix2): + #X_test then X_train in this order + sum_of_squares_matrix1 = np.sum(np.square(matrix1), axis=1, keepdims=True) #A^2 + sum_of_squares_matrix2 = np.sum(np.square(matrix2), axis=1, keepdims=True) #B^2 + + dot_product = np.dot(matrix1, matrix2.T) # A * B (matrix mutliplication) + + dists = np.sqrt(sum_of_squares_matrix1 + sum_of_squares_matrix2.T - 2 * dot_product) # Compute the product + return dists + +def knn_predict(dists, labels_train, k): + output = [] + # Loop on all the images_test + for i in range(len(dists)): + # Innitialize table to store the neighbors + res = [0] * 10 + # Get the closest neighbors + labels_close = np.argsort(dists[i])[:k] + for label in labels_close: + #add a label to the table of result + res[labels_train[label]] += 1 + # Get the class with the maximum neighbors + label_temp = np.argmax(res) #Careful to the logic here, if there is two or more maximum, the function the first maximum encountered + output.append(label_temp) + return(np.array(output)) + +def evaluate_knn(data_train, labels_train, data_test, labels_tests, k): + dist = distance_matrix(data_test, data_train) + result_test = knn_predict(dist, labels_train, k) + + #accuracy + N = labels_tests.shape[0] + accuracy = (labels_tests == result_test).sum() / N + return(accuracy) + +def bench_knn() : + + k_indices = [i for i in range(20) if i % 2 != 0] + accuracies = [] + + # Load data + data, labels = read_cifar.read_cifar('/Users/milancart/Documents/GitHub/image-classification/Data/cifar-10-batches-py') + X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9) + #Load one batch + # data, labels = read_cifar.read_cifar_batch('image-classification/data/cifar-10-batches-py/data_batch_1') + # X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9) + + # Loop on the k_indices to get all the accuracies + for k in k_indices : + accuracy = evaluate_knn(X_train, y_train, X_test, y_test, k) + accuracies.append(accuracy) + + # Save and show the graph of accuracies + plt.figure(figsize=(8, 6)) + plt.xlabel('K') + plt.ylabel('Accuracy') + plt.plot(k_indices, accuracies) + plt.title("Accuracy as function of k") + plt.legend() + plt.show() + plt.savefig('/Users/milancart/Documents/GitHub/image-classification/result/knn.png') + + +if __name__ == "__main__": + print('milan') + bench_knn() + data, labels = read_cifar.read_cifar('/Users/milancart/Documents/GitHub/image-classification/Data/cifar-10-batches-py') + X_train, X_test, y_train, y_test = read_cifar.split_dataset(data, labels, 0.9) + print(evaluate_knn(X_train, y_train, X_test, y_test, 5)) + print(X_train.shape, X_test.shape, y_train.shape, y_test.shape) + + y_test = [] + x_test = np.array([[1,2],[4,6]]) + x_train = np.array([[2,4],[7,2],[4,6]]) + y_train = [1,2,1] + dist = distance_matrix(x_test,x_train) \ No newline at end of file -- GitLab