From f4d3c48d4f5a97c08fb58bde2b6fd3999a2a3d40 Mon Sep 17 00:00:00 2001 From: oscarchaufour <101994223+oscarchaufour@users.noreply.github.com> Date: Fri, 27 Oct 2023 21:27:59 +0200 Subject: [PATCH] knn modifications and mlp learn_once_mse --- knn.py | 71 ++++++++++++++++++++++++++++++++---------------- mlp.py | 40 +++++++++++++++++++++++++++ read_cifar.py | 29 ++++++++++++++++---- results/knn.png | Bin 0 -> 13205 bytes test1.py | 52 +++++++++++++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 29 deletions(-) create mode 100644 mlp.py create mode 100644 results/knn.png create mode 100644 test1.py diff --git a/knn.py b/knn.py index 0be4178..1fd7fa4 100644 --- a/knn.py +++ b/knn.py @@ -8,49 +8,65 @@ import read_cifar import numpy as np import statistics from statistics import mode +import time +import matplotlib.pyplot as plt def distance_matrix(A,B) : - # sum_of_squaresA = np.sum(A ** 2, axis=1) - # sum_of_squaresB = np.sum(B ** 2, axis=1) - sum_of_squaresA = np.sum(np.square(A), axis=1) - sum_of_squaresB = np.sum(np.square(B) ** 2, axis=1) - + print("test0") + sum_of_squaresA= np.sum(A**2, axis = 1, keepdims = True) + sum_of_squaresB = np.sum(B**2, axis = 1) + print("test1") + # sum_of_squaresA = np.tile(sum_of_squaresAVect, (np.shape(B)[0], 1)) + # sum_of_squaresB = np.tile(sum_of_squaresBVect, (np.shape(A)[0], 1)) # Calculate the dot product between the two matrices - dot_product = np.dot(A, B.T) - + # dot_product = np.matmul(A, B.T) + dot_product = np.einsum('ij,jk', A, B.T) + print("test2") # Calculate the Euclidean distance matrix using the hint provided dists = np.sqrt(sum_of_squaresA + sum_of_squaresB - 2 * dot_product) - + print("test3") return dists def knn_predict(dists, labels_train, k) : - number_test, number_train = dists.shape + number_train, number_test = dists.shape # initialze the predicted labels to zeros labels_predicted = np.zeros(number_test) - for i in range(number_test) : - sorted_indices = np.argsort(dists[i]) + for j in range(number_test) : + sorted_indices = np.argsort(dists[:, j]) + print(len(dists[:, j])) + break knn_indices = sorted_indices[ : k] knn_labels = labels_train[knn_indices] label_predicted = mode(knn_labels) - labels_predicted[i] = label_predicted + labels_predicted[j] = label_predicted return labels_predicted def evaluate_knn(data_train, labels_train, data_test, labels_test, k) : - dists = distance_matrix(data_test, data_train) + dists = distance_matrix(data_train, data_test) labels_predicted = knn_predict(dists, labels_train, k) number_true_prediction = np.sum(labels_test == labels_predicted) number_total_prediction = labels_test.shape[0] classification_rate = number_true_prediction/number_total_prediction return classification_rate + +def plot_accuracy(data_train, labels_train, data_test, labels_test, k_max) : + Y = [] + for k in range(1, k_max+1) : + Y += [evaluate_knn(data_train, labels_train, data_test, labels_test, k)] + plt.plot(list(range(1, k_max+1)), Y) + plt.xlabel('k (Number of Neighbors)') + plt.ylabel('Accuracy') + plt.savefig('results/knn.png') + if __name__ == "__main__" : - + t1 = time.time() # # Example distance matrix, training labels, and k value # dists = np.array([[1000, 2, 3], # [4, 0.1, 6], @@ -62,14 +78,23 @@ if __name__ == "__main__" : # predicted_labels = knn_predict(dists, labels_train, k) - classification_rate = evaluate_knn(np.array([[1, 27], [100, 300]]), np.array([0.002, 9000]), np.array([[25, 350]]), np.array([9000]), 1) - print("Classification rate:") - print(classification_rate) - - # file = "./data/cifar-10-python/" - # data, labels = read_cifar.read_cifar(file) - # data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, 0.8) + # classification_rate = evaluate_knn(np.array([[1, 27], [100, 300]]), np.array([0.002, 9000]), np.array([[25, 350]]), np.array([9000]), 1) + # print("Classification rate:") + # print(classification_rate) + file = "./data/cifar-10-python/" + data, labels = read_cifar.read_cifar(file) + data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, 0.9) + k = 10 + print(len(data_train)) + print(len(data_test)) + print(len(data_train[0])) + print(len(data_test[0])) # dists = distance_matrix(data_train, data_test) - # k = 2 - # knn_predict(dists, labels_train, k) \ No newline at end of file + # knn_predict(dists, labels_train, k) + classification_rate = evaluate_knn(data_train, labels_train, data_test, labels_test, k) + print("classification rate :", classification_rate) + # plot_accuracy(data_train, labels_train, data_test, labels_test, 4) + t2 = time.time() + print('run time (second): ') + print(t2-t1) \ No newline at end of file diff --git a/mlp.py b/mlp.py new file mode 100644 index 0000000..825ce39 --- /dev/null +++ b/mlp.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Oct 27 16:48:16 2023 + +@author: oscar +""" + +import numpy as np + + +def learn_once_mse(w1, b1, w2, b2, data, targets, learning_rate) : + a0 = data # the data are the input of the first layer + z1 = np.matmul(a0, w1) + b1 # input of the hidden layer + a1 = 1 / (1 + np.exp(-z1)) # output of the hidden layer (sigmoid activation function) + z2 = np.matmul(a1, w2) + b2 # input of the output layer + a2 = 1 / (1 + np.exp(-z2)) # output of the output layer (sigmoid activation function) + predictions = a2 # the predicted values are the outputs of the output layer + N = targets.shape[0] + + # calculation of partial derivates of C + dCdA2 = 2/N * (a2 - targets) + dCdZ2 = dCdA2 * (a2 - a2**2) + dCdW2 = np.matmul(a1.T, dCdZ2) + dCdB2 = (1/N) * np.sum(dCdZ2, axis=0, keepdims=True) + dCdA1 = np.matmul(dCdZ2, w2.T) + dCdZ1 = dCdA1 * (a1 - a1**2) + dCdW1 = np.matmul(a0.T, dCdZ1) + dCdB1 = (1/N) * np.sum(dCdZ1, axis=0, keepdims=True) + + # one gradient descent step + w1 -= dCdW1 * learning_rate + b1 -= dCdB1 * learning_rate + w2 -= dCdW2 * learning_rate + b2 -= dCdB2 * learning_rate + + loss = np.mean(np.square(predictions - targets)) + + return w1, b1, w2, b2, loss + + \ No newline at end of file diff --git a/read_cifar.py b/read_cifar.py index 6d5369d..f25324d 100644 --- a/read_cifar.py +++ b/read_cifar.py @@ -24,24 +24,41 @@ def read_cifar (batch_dir) : data_batches = [] label_batches = [] - for i in range(1,6) : + for i in range(1,4) : batch_filename = f'data_batch_{i}' batch_path = os.path.join(batch_dir, batch_filename) data, labels = read_cifar_batch(batch_path) data_batches.append(data) label_batches.append(labels) - test_batch_filename = 'test_batch' - test_batch_path = os.path.join(batch_dir, test_batch_filename) - data_test, labels_test = read_cifar_batch(test_batch_path) - data_batches.append(data_test) - label_batches.append(labels_test) + # test_batch_filename = 'test_batch' + # test_batch_path = os.path.join(batch_dir, test_batch_filename) + # data_test, labels_test = read_cifar_batch(test_batch_path) + # data_batches.append(data_test) + # label_batches.append(labels_test) data = np.concatenate(data_batches, axis=0) labels = np.concatenate(label_batches, axis=0) return data, labels +# def read_cifar(directory_path): +# batches = os.listdir(directory_path) +# data=None +# labels=None + +# for batch in batches: +# batch_path = os.path.join(directory_path, batch) +# if not batch_path.endswith(".meta"): +# data_batch,labels_batch=read_cifar_batch(batch_path) +# if data is None: +# data=data_batch +# labels=labels_batch +# else: +# data=np.concatenate((data,data_batch)) +# labels=np.concatenate((labels,labels_batch)) +# return(data, labels) + def split_dataset(data, labels, split) : number_total = data.shape[0] diff --git a/results/knn.png b/results/knn.png new file mode 100644 index 0000000000000000000000000000000000000000..393d194fbe612e40cf83ef7dd8f14d34f3fc698b GIT binary patch literal 13205 zcmeAS@N?(olHy`uVBq!ia0y~yVBEmKz^K5%#=yXERYNV4fq}uY#5JNMI6tkVJh3R1 z!8fs_ASb^hCo@T*EVZaOGe6H*&s@(?M<Ju6q@dVJUq2VFKrg>2UH>*0M>_)p180Fp zWHAE+w=f7ZGR&GI!N6e5<muuVQZeW4-O7-V@I|)&VqF=#xSF~cxtjxeLV1ih-HxtU zVd!Pzwtv!-<%J#m(aA!_Gp@dzTJn9fbs1Z`i%^q)uCLVt<E-R4w*))4bu<XEP7QKo z3!K>DqZLqESN|hOmPt#YYUj~^Yp$-|{mz#C+sfOUY(px&6<Y)pl$0h(F1^Co(a~|M ziPg=;Wz!KRWhJG^q((u(i8>n(baZ%3i%4*BQSsI>P*R$tDkdQ)cv6VFr=ugKbMPr& zGRf+|;!C^B^LOmp<t5wP;St8g#Z~b4SLw>+Egd(-W~^SVZDeG$q|ebsMO#r(F~~i` z@E_MiG{0m8H@>~MS2`{(?m~u1#IBOe$)D5D&vSKgVOg_w?Z=nP=Wok+ci#Sg&x#cq z6%`c<T3TG@d3Oxv{(iBzf5F9!2Tz~2*8hDSpWEr<;*z%PDr=g7ghBnknm4z$a_`)^ z)5FKdB`+^;u^}HH-;Y0^&wu=OJAZq{yGu*G4U3+5y!cwxBV&0fBoCz0%gf2nkFV_g zz25tO-<2Ck@Z`qo3kuHc3=@y2s<JYX;=Oa{PC{yGYv$!;ik6l;FG{$(ySKEp9lCUB z)6wF|lP4P{ALC&-&?9O5;e7qS%Oy3pwq$PDw(VF0BlCib858EsyEkpCkl;kg%&)g! zczbz0c=RZ#YE9Vc!Z$ZIw$H2n`YP1I(vp#3&YU>~Z*Q4ip7Y_u2e)D`FRv)+-ABKK zuMP-^`+Z`c-PYb#A&EElq+@k-m-IQg$hxfN_TjkrWQNniij?ALV|7JEMovynhK5a> zHi@PFT6;E9wZpz-p>Ofai?giH@a~s?u3XBa^hzNrHX&=Pl&h=jvg1dO9xW&<t9!At z_V+i}%uLNWbLIp@MRB>i9}lbFzq9yx(z`o58E*aS{}?Cvrt9NBzuZ~g{pZ3p#W(!B z&pN|6g+V5Fw})(dhgCzEctl2q#+NGFPoF<uTpw?*qN@6E_xpX-FLqwMcrjvso$bBq z_qJ=|_sgBP{XXaAk8b^a9j@JC4?jIUec|<24sP!3+&5y?ht_e=+whqGMBlT!8{=-w zEoO~h7IQ#;@z$W@ii`LbTxDGrv~ok%RjnRLV>Y&C#?5JGo5EJR*3{TAD43avg|CZo z^zh)&($-!WqO~#UDA$u`&kkL?w(aIlh7b3u-zUDhvNB?G8t<>~@B4jcn=vwMiPBx> zH<v3uK7K`r){&z}4<2e--&|tm|Ls+q=E1C*@-+g*?K=<dOb=hFEI9Fq$fBL%GtW49 zcyPRL%}DFYzw+p%fA9yV4cvBB_xxX3xTJad*;`kt^G)L2x|cWU(68iRH$kqSeu0bp z`dQsJ8U4@X)tSPz=Fa-*R$PLfJ*;k<E?wWo{cryV^^IR$4E3gjdds$Vs2pKZzS$}J zu1{~8Sj5i$-3E)!fmB}Jz1uX6r|c-deeAKF+Mz+m6_q+S9PpU-FMEdcuIQB;YVzk@ z3GDDt(lMB{be>t*4AT^bvga`eUrw40Qp&mEfXB8pep!ay>!w8<i<Jmm4zgv_+oaB& zhwl_Tcw?F(davRGi;|JHj=`j@;cUAXg_~>Z+&&yD5tx3MNjdU$+S)^V9=@}v`DfJ^ zt25n-Ptev)$6(Ul>4)dHB=cw2Y*Sx<^L)Ru(r1vduahzxS(QyZ^nXaUpN|JCR?{&s z;{LCHh4oxn><(__S1u6=DPn6v#VT@_u=|OHJ~(hQ8x+0A8(7`a9=9f#{<6N2a5L`R zp$;o%F^QSH-rTBj@-G|boSyn4{G8~@#{w6Xk1;84HY!Z8-|_urTWa;^1cM@Rx0G`c z2`OPZQ@HjRo)WGMU%m0~^3xz=-8)#_(w;}!t5?`xWY+%n@?wI4mu!27oS>M*OxdMD zjpmDOCf4s|i}s%aO2%N}r5E?rKD>UOZ@&D6#p@b(8z>bYVN%|_BtT<>#wlS}`}>EM zolye2;{cQLX45;X`I~-yKeCGXdExGR3Z*hiS?M|kM$ufYO!>XvoL-*2?ct+8#fndG ztFf5G%&iCC?G<?UtYJs9c){vEkkN;jlsA7nt@oztt8C5WFtC1i9Rs84h3ohvUR{)q zUmD}^buQb=#{z-YiH(8FKU~Yuzq4ZP@%#Nxyk%QEZgFz=ShYG#pC|G3aYxIQ?j4n@ z<6c>~T;YpIxRSX?@@eO(+q@B<AG@wxED&g&fmzTxap=VC5RmF+`}gnPiBqSH=0@)< zQZ2EP-4dl6wk9IcYTn$ry&WAKKYsp9Jk-LuV$B*BhAUUECO$he^TX%Q&JGR?K|w*c zgl2L?rCaxWu4c;a7kTh}mK7*)oU=+BC3)D+%(s`{UG~-~FHcWMSa{;>*|#&!tc%^f z;QH%@*I!SZFrlIH^D{;Uzxj5(S67EC+uF+Bym|A+zFO-qRkjQX=jK>O>?+apn`_nT zIa%%DWPdwGhBa|}r5FlcUD14Vcei-{-mhXkZug`*ZawP_oVQ|S_JQjw#X$ul|H`YZ z)8@_N>vdx`H#d)nj7+?>CDX&tFV95H(o*vBa{qRxg$@-J7Ft?b0Z~z1#_8u2Y;0uA z%*+B-hII7w9J#hOnvvo8xw(foBpzm92n-Cov8U3QLBZTyT-m*kg~7tgYEk<6c_&Vs zs950|{QT+7?d<co3y!tb$2&z{l?8i0qqK4Alqm;}9zFUcVr|&sB}<mLCYM;r1_cKP z#>MsBx)t^2&Q4)h*Yfwl%P+s&kbB#V;lMKA**?danV2TbnbXtO#`fg-b7f=W%O!m) zbmQcoc3L$H#xE@pcT-7Ut7o7%tM1fVr-d08CFJGhb>j9&1O^65SeNNAF)=L&TFJnW zeSMv4OpHuZb8}#5D66)%Hp786k((EUtrip#YBIl9!Q9NwZ+7<Q+LjiUmX;O;6&04$ z)YP@HQl&!7Rpp1DEU5eO?S|@<dwWzt>3MOO_>8Evd0UvZ^z_;`Z!YdWerBGnbYx_t zhp%sIb2GDzjSU0Cy?gg=Y)<DlGBy?z7G^FlFK1}*T6*Zy)6<2Yo_MOLt1n)(=#a_t z?a$85?e6K}Nl8gr;6L9EWTVph-VY4|Ooi;{!3lcBRaP-6DXzV7{eS+{T)BExQB#xi z;>C--$6VaqlMlD?3OFg~>hhMAl@$~fF)_sKF5_hgh=}OWjozl9rnc;2;o^%Os@~H; z)d535baeNuSyBuK?(MB!7_{=i>({rBJr%n9^ie|dN=3aJ6?-^b?XQChgIf!J-amRb zzsajLbahzbnHh!#5<D|Zv&FVV={|b=n2}-Q#*Gr@c`{}1?r`?HF-uEJuZiA%&ZPd+ zlam_y`t3`Xsy=%3C?P+ep8?d`h>4M5c<}T2{Nrb48Z+L}-Da44%wc87qWJxFKc3Cb zm#ezy!muHsg>AM`_s64*wsv)j4xkL{#dYxxcUYQ&ib~7-uqEujF7sR7DEK1uO}Fbv z!@avp`kYlPE&slGGS^~7@NDxtvSDIgYwk=}FBMVBQWcXZ&FZt=mGJ9<`ouZ4lD}4i zDu_j435|ivzsDcGlbD>@`1DSz;7Km-9xL1Sys#OzDGX)vb~Vj8ZN9Y6*+usVlQNS* z^$O{n-77ZKtdD!8;c~@E$3TgpELPXBe1`P5xjVZiyNZI!Qn7?a!N2pv=SeYCOMC0w zK7IXA$CCzDw=ea(qE`z@zwEoR$4Ey1N8~FDmoE%D28-fn%02A5cX#c_oAZi;+!YtG zgS2~vYjdBeza^U~UVm&qq=vZ5&fW9EJn7J{!zb5q+ubc*ce9^OX;EJztJ@b)LkARE zk*%U0Cm%mQC0H7wW3VXAhqJ8l);{fNZ%;nf>sis9)F^0r*lX#L*UGj9tDdNRdlEO* ziciqAnboa?e=^&*=IHl%5&O*f-IU7{U1WK=dtS(eh<WX?JH?#7J|yw@${<jN&EIgK z<6ZKe{l+%G%Z_)|RvHwDyIolb5<Kj*^1$rHHiZk<u$@%*Uef36vi}g1a#e%jE%gPZ z+m3eCR$Y<q0R_+h!%WIm3`>I=T?6N7{Mjis+qfGP!2jD>-L5bx?~xUJaY=ao`|FSW znjK$RxUh&x1a@nPaP2X@B|JU<-mznkT|u?=h65c@IUUQ)55>rCcy*xi&bsG%Jsw#) z28+agrHYrWeZ{xWcehHZ47gQsg{%2aRbt07^L<`B8)jMa1|3&?btNLfW$l)XE$XMT zUw!yqFYY1Rx*|BaQE+QeKz()3hcXwZSFR!V)=jtK3tWDbNja<8HE`aJ%GGk&{Lj5* zTUT68Y82esxNg4ev8?$AX32jJ42)I>+1APGcEv5<t#Z+>b&Q-V#h3OuUD4clpySqw z*PU__@89>7x^orm{ndA+u)}JundpsocXvO0`joZTt$C^UbTdo+mIH@cxffo{c<}yx z|C%*Amv>Ip4tH99*>UyNq@!J;%O2mnd9$OVqhbDhc{QII4W~|dF*wxL+NPbE!Dw%9 z|KrDx2aEgdvb<9kxj*U+l-qGR!J+09>&nFfrB-Y844Mx%{QUep`TDxp6KBu5CM78q z6&FkQ-rSlk{-w(HOO>sVkdQ(CJsXAtFE20OSoJk)$@1mO*4EOux97LVt!HOAGt-#e zd%E7ymBGsmGB2srd^)Ml(BM1U?BTh&))u9&M2w7$KYqO)zj(od2a3~uuh`xBU2=R$ zV^+=NGR|=Ad~ezI6~Td0y=#A0bbR?zl3^mXW9QC;r%rKQUgqmuR8+Ls@Z`ypGwf=u zM7UTF9X=cw7}$8}QqUaBVz!45A8y#VvC(U(Q%Z`;w!FJs3>M|@WZHP8*%&&yx(+=% zJDc&&-rtVCzPwUWQU$NC=|)6G9=vvKo9r~dD|T`6FBzKy72Z9l1cic@7e7D$p5O0w zTNFR@=~kOUOuduevkF@8l$Xcum#Kf)Dt=*Yw7G?)<-+>^|0;gJ-Tv^wgAW1SYFYk& zt8XTFb#PC*)Ey91=>5vVWp3N2Bb~wp<>jC<PEb&=v89D2BsBDHag%_reOT<mfGrt4 z{r&u6`f;F6TaT=@7$YNNz{(JYhOODxmGt!Z4jw$XV%4gqsZ&K8c3azVb92Y+tC9Tq z^XHE@o6i?~O#1x(mzRf!gNFx)lao_GSlG2k>WwcF_nJ8|Gend+=obntI_I33o!xzM zvU>8ZEtwj+x<{|AjTRIUY0=;JW6@2{tgW}USTi*|dX!{gZN0tX%$u8=Cr+K(dgV%p zg@wg|!-ts}9=v_qYg_$I!Pr>X(9p1;yquk(=KI}pP!+19tNZZ9i;m^<s<<R2CC{eK zwl3E*GBTPlablxg?XN9=BNek$pU;c^_-MhuANPBbVzkyz1m&d#b4)UY+<GJ$fBmXj zvV1u>Ii5RrZ2A1UTW?w}TnKpawd%?9=Zn{^<MZ_N?C9&uGf}g$l9D#dX%Jw_%*@=d zd2{pR$-)c|Zf(tGXgGApDd*-URzpL>-ahGGH=W2$EE_j&WN7&J_xIxijm!d03zjWo z%goID@a<cjia*=M!$0m8l&=7FFV4qGTztHs>hoE1-~4$$T6=qWudR!{{fA{L*Xppf zhh`WiCm2W=WM9+C*cx@@_;F^2q~v77ygL?e?(7tn>SZ%CGmF@sH#hU&larGJ<Kp<d zy}gedJ$mroJvoL6GiG$`*kQqN;PZL={+Y(<ObpM?&j&F;*_fkorg8eA`}g<XozJnT z{b|N+R%5o^kLv68blm5??(q0V>?;qKG~dG4*K{u~^%iG1ur_*o!Pi%zpizU_t7KbR zTJC%f?{(?-s{AU;^Y@2>TZMJf)o#b2<BD0PmVe)bXife9b9sW8nAi-X)UKAojwfrw z%d^kaoD#iVuXS|hC9jo_1!kH))qnqunL%4y`^wdi$CkM8C%+3Q|MFtdp-+1&gWMI3 zvQNeTdn;P-U0&h)ySvecEL=XNrzdw>I^Wr#{$0yY^!^kpzKJ)3{(W6v@B07W--nML zHSPcVHvh-3*Xx&_b5W_@+_q!U@?KTD+tq%rG+ffWr%s#J_Ucs@hvJ3B?))o13(Rbl zi`DrUfB4SF?RV2BR=#9wn!jbbmC(eSfp2eYWVW=l+)?}6-dySLgJyn}9+wokG@d%4 zY`*v-oAy5c&~N!L=(wVh^{HNI^GDtK`x*q8W}D?6TIxN0$vKx4v79v_Gt6U}m7kw3 z{Nl9pN@2&DzM{izyc@P}Ki)21*Rj~W-)tAFh@kD9>1<!!<D;{6Zdb3nS=}NOcw4h^ zclo^?H9w7Xbaf}rpRZr4rnLFu(MwF+^R7LpILN!?)p~E)_7iI~w6(N8oU?wfprOIB zZr!?N)xtu7zh_9jYO1r{ope*~-Jy<KlUJpmpEtv%(&$*feE;q3`G?P)^Lypvl6EvI zwXs#%G-A(Bk#F;Nwgv91UE1f8BIfgb;-+oe(kcaQmCm!Te0b|L+x)j@n=hV-l?bd3 z$h)^^=KgC6RY!Dmr&QEVVLvw|EU6savzy8Kbc$whLT)Z^baeEW7Z;hgZQB-<u4LqW zk@4?>!Z&@OzTN}9o)-%>d)<^@y?PZ88QHmNl~zdDC%#)Q{Y|Ywjppu^kN#}U`ZJ*x z)C$?O+{@QDx9G!m`2tUuw~r<73V*0u#2)tM^x=wn^%Xi(_~*H}oDHf4m-x50=YzT* z2M->!D0t9tWMjvf$w4a*L@(a^;Z++?&PlC<R|3RWDhr;R_NraJ&cIefOY6|?_xrRH zWt2XfpE@qcm}R}@&idm=8viCNeWcg3LRaeV$-vcDEoy&l35o2mVr_K-^)v&63c3H! zJ{Y}SaO>2phlg4j7#3gbn5^!vq@dvN$iwBC{-w?QPV=@LU&Z|VqtgZ{@F<UqiuI}o z4<79P@4VE>$A@QL)ho?TWx<*Dtxk=ui)}wXdX@a+U4&v9ucSh$iqg%%y?;KP-mqcE z@9p-@2ODnPin?**hKEe|i9ft$?|Tov+xsE)RHYAhabxdPuU8%}n-=XUeLd~v!L`xb z4YICia6T3&)ywbwcI4|Ts~n3{=N~F(>85^vch{GHSNVHTp9QSopud&ErOo{N>`E9z zx4&N6=d7ZA3RI^zpE=`G^7dBfAq$te=c5<f7B1c<_qtESL$=i;EF^fDPvhpz#taL# zY!Q(*&zoW)6!`qn`gffg_xDBoee8Osqo8$e@6sz1J5n4=zrBf^V_p8e`P-hodlNHH z+S|DF-(IBQZ?pN6`<j0rmF~$s26am!1Fl`U;;{Vk!Sm<&uU@_S<I`#VrRQ9(y!>Ht z^HF5I`Wov={qwV5fm+Vor{>vK&yu@w`*!lzS644BReW{ip3Kw6=&x)wwJ(~k1a_o2 z{(Q~9*+hb8#p=iR|8*=`Q23^?v+?hSe;JJHLhnwq;tOnFKF6w5>)ZSL{mkrq3TkR> z=WRZpVOB>Uw?ytv|M+Sx8@?`vbFLrz+OTd>?XVer>FMdpN=i(1b#*fg5}RHYpZxc+ zzh1$>;KCRAr9qvO)&1?h9O|@M)T(S7kT3A#lR|l-3;!Y+%OaJ1zg}sB3gY|s<@f(< zKYP|tf@g+(y<OScTdjBQ#HgsLDVdvZzc?dycbVh*>(1-1D=R88etveAnPHAyt<>u9 z^=?&FR+~0&-dOv)?8VnAh68gfi#L?K44PwA$~9Td_mIB)=TDcs^&dVtIoYD}lghrI z&!l(k+NES*u;K4~&P99&ww9awaE3kJo!q)3rtrUgx69nQMm9Dwd3kv!PMkRK@9*!2 zj~^G`=Y09{<%JBBg5u)lsZ&Ll`Oa>8^5ltI@}$X=4QqduXozsNw6`x_w1{cx(xn!~ z&v*_UKAf1G%skbr^~4E}oV&Y384BLrvAoQl>E`O1`1#pc-+kptNlHmcNw+dzy2xH` z5|rZLSYuTI>YvEYDmvK2T2Q$2_rfZ6ez_w}tlaOc7cN<H#6)h#Z$TYZRaIT7#g|`B zm^G{G)2E^<*RCZ!KQ~uHYuE2WA*;E3x3}doGvwUba<b&lb+f!XCzzBsFW<KP;Fb&j ze%ycTs_|=P`7y;^;!htQZvS{Id;QG37rzzu*IhoYFT&;8eeB1N9|<2H9R;<}Zfr~z z6cAvT`sy98&-Hb&-8*(zu&}c)UbTuVEG(?|nY?uwkC~a-g_kAY&rF*zp<%)VftHq* z1#8x<aeb}$sza>j!U3n12TuL10=3{GPKC$UivIroK7U)uti@rgrQh~-b$9!gf66#H z<L2gcWo>Qll`B^k6cjYHx3e=GxP4o?`r8}Fv^2G(q@)M$-rakVt){AKTkzeq{9Q~+ zN{WKAGIQ;(FPRssI;<A$+ZOww=*Q$6lR|sCQ?G*h2ot6pZs$J^>YIOk6?$g2Ie+!H zH#^^x*-AMY;<aDKT&g!>QwnENa<Xsy`3#e!%gcOcnB_+8)}6CrF{r<KO^bWYjpMNr zfyGnw_kI!DmVdwR(W4|5Ha4^0+kV6cuMCOUnk9OyN0M1yUfyS}m1xvjv5y}=cJ%kN zGxSKC^X28`UC1!0_<A*5`tIH2Pft#M`1<u|3#YJvxHx-tbu~lz{RIzSzU0i$&$qC) zPJVV~<}%BbD^~`rzFJURT+QOo7<l~3?%k$wm8V4Q|CC5RUl-yj+rHv(V8^t7?*h`C zoSY_1oTymJv#7mj@iIPVd%KQ>a(h2sJ$*3f_(DGiKYQz`LF@{PsyDw+YrW~rmT>*0 zwnc1F@GA=!SuXCL70t74OBmklic0!<&ju8g8af7x_|9$XR(-}-eE8qy>jjHGw1K)s zJYo_PdG;9H5aqkOW<$?z10`XQK*{E?#JDBabywn#Z|r^wiX4H61Q(-^ML#ANyh-dm z|Lxh<AIF$op7pT0sR+le)iL;=$$RN@{jrIk@3abTJkXJn5;cQ+&f8rNHt&7DgL^S( ztXW0*nilsSvrEkCT-pU6?wGE8Ea2J2>ZW3>&$9aB{ex`t-#*Mvys{Id`Pe0=&g<?~ z3eSsF|4l7+eP!V?DX~%T<b@-Re;-_1E<Nr2)x#N2^m?q+K~9a-5Uog_qJA!R?V)eM zOM{LpPGaHi@mO~3x8=dr`TT2c%x-6E4t!<d(k}}Vp5n!NujkW|qpwYWJU<B<iLH7A z@|c<1ZSh@o+ah+_^MAe58mNAlNqLfCt5f6lMSB9ieAdl5r7_Kl?_%;%Cgn*xAI8-h zY^e9IDRk(+{S`C>YHOxrpfodBL!_d5i?~<*J-48$g;xqY<fO$UCQ3@YxyRkiW}%?> zlqv6Op}5;zX6~L7Jdh4HcWP+)2j8ehkYa8ziHS2gq#raal-ny<|3YcidlL`Y_7@YA z8wF=h-0oKW@cAsOHUG;SuLO3;S%M7aEP2@(XxI1S*9F!)dqVX~Wfr-2vAU(C-Rt?} z^7Jj+8p}EQr6N$VE$7N4o_}Y#CU@?5FL)f()=0-d$@t4L4gV`wOdq_y(jDZkSd|iy z;F5M?b3VTdyG8Qf{oDazrS71BW8v;O;j`E75_2h=@;<Mf3Ls;fSlv>V-T75={4?|K zhqsfb7d(BIyYjKXR%?*cqAe~lJ8>M5wE?-H?D{gM+k30c_w3oT;9|yxl9!j%p0^4_ z>?lw)k>dUH_wU51Q&X+B7d`cQQDT*`HEK=7#zQ4G*4EO3f`T7De?Hv8DI5?PnQ8U? zvcJ8ek`hzYTCuvix*fH%>u0NEaf-7``|uVzUi#=Ha(~5CE55+rLTj&FxS(+7&YdHz z+~P;(T9>!%-@pIF$&;BTX~+9yliu9ecqqf9Yq5KO*M<!SiY)@m{pJcqtqr?!<x0n| zUsa&lw~5oItG5U!si?RtS+ayfu_a=2TJMj;^8XZC1QzYv7x&}mPf(r6rP#8g?(eRQ zjEstea}ORoxTe>8LML|J3Q@WD??v9{otkFFb+TbjUD>uO_Gwff@l7~C&-TOj@9qu` z3=9jx*T-$xu%TkbCAC*RbKdNFwAAtM$J#JMu65TvWZPHV-LmU|7h+y7FevEKB15aW zd@EP3)X>*=&&|~>dv~XE(IO>=2EXNpx8>euV3=)|Yb1OA^yy}&g$HIBCJTs&wDjBm z+wpa^%Bv^Oc+1<DwlZmde+?RI$%~3jaOzMyj;jvxVr7WfU8ei!(IWvFnVwavv?3xS z85uwW2R}X}{&_6_|G>k;?Hfy9hrPMKfB(%KPEJk^PEOEd*{M^fqIR1oZGL)VF?X8F zjz{y=y%+3{QZ1EPRCe(dQ}F7m%kBy?h=_|HzIk)!(d4xA^Lp2;(OI%=*`b368Rg{U z7#Q;N^WWUu%szAG%#O~^#@)Mjzsz=Zb6ZsY{@#X-8xQU%d@QwB^T~CLVc&_7zt?+| z--w7xIBNIq&<UF@yAC)xJA>u~=FjinwaZFHRaMc-YS+aXFJ5Hi+}ObA#Ia=Q(u3#D z@%j1rF*IafU&qM6#m&7j^|aWLBS%cE@=T;QrkoTya_krr!?7O8=J|EMIJa)ynvs!l zV4-t6Xyo^ssBmcz-`zC^*(Vs@ScG=nIX#tsp2L-cTX!7@(Gm??8`Wx<eC)vS<Ltr9 z{aPWDir}%|$H#h`XU~>4k>ZW3{Tf>H^Xc?uKPRjEx2<2V4;rf4vBToV?c2-@Lc+qR zH4p<sb#=8))E16~3m0<m@VEp72+XVfcJo_=b>MR2bjWCHPCjVfPHvWy$@iwIlAsuI zb8EX|*kLtC|J&=2H}Bb~oUi9%)U#dM=PcVZr>^XipNz}gGiMZ;6Y`T!#!8%&<L+58 z`DNZ(r<^6$d{3t;E`E>>DjFoYdsYa~f{ne(y*uQwFCxKZZUA)T)m?wu<NMFogGR1= zLE<4k5%m@~OAn@3f1b7XCqtn7aVF(gOLj#>CH%WAl@hnB(YbO@G4Dx!kg2n!9x-jL z4cB?B*OSu0>UO0wa^;i=i%ZP*qW6#7yR#b9V;2>Z2vq<6=Rlprh27HApxwTc>>#HG zM{0aLrolgFZ&2&p<=#vCoUbJ57$_a8s86r3C_DaXFL=<^=pd7_QiO(RgvirQ?o?9` zyU0%uI9+5JxO+Mn7?QX6AGm%}O#A!mz<kg=-iyYF1Q!;D04>pq>@DhgQ?Dm^hiwFv z@*S*hE^6>GSNnczUQh*cSDd@2!$+iR_xjeUA$Jb_m(!nQ#dooG<AIJ96G7F%vsVrJ zd~=S!t$KH8#p9$#!9ZoMR;F&Z>PJymWou@adA_o6`F?~+d6C+oyWnxxNBpr4XTh^` zrroS=E@z`mZV6vs{d&W{3{dqk>nM}*qBPd8+mBz4tNrkGVfTtXv7qWhEh53C#J_vd zo`TO|w%6p&O}F9-6xK0N+NEy4RrXlM@1(Q+KUb_+JK2gaP#Gk=-h8o5VC?mVtM+kI zthgXz`tQ|q4Da2SeBCFqU=MiQ)h8ms<%>gp&!+=#tC%10$4;=~n#jrB(;;_auYb9Z z)QV$mi>DWC{j}{$VF$y8105?`A!CRSPB%Ya+ULa5!0P6rdu9K0-eoN-CeDa!nbZ{> zbX*Y>2B&o;9x}oP-ps_PF}CaI$hfQQZP$SV4lHbJN@ix;bSHMMkp1%drl?_}-i;}N zJ%>I;t$ZAit{$~=;X=n9J9l<+i|ch*m%rnvudkmtdGh6qGkvnwM{aCPZoQZx;@&SO znwXgA#G&Xn*GhFu)Y=s*R|<ap`ZejviHS=3`u-e>E*>5o9-f|pP8=K4&Psi3xBuhF zq3BXqx9>vM*0sHwP0h{Gw(T8qZ(f~^KJ@wT+@y=gCI`7MbO*KlqW3a;`}?=MEpD7W zTYAl!H3qr2OjOj>&7~&KnbY&~@^a^e0idZvQBhIQFyu_*bOAXzK51#`6(L$KZf;4h zudQWhD1Lq}@y3QkhK3a@G;;3k*(v*KX2-4Bi*v<&`YVK@{Y4)C*_yTTae(@=thX(b zCJF7>v15kZ>Z@5Y)@5%RlfAsW7F>S$;otA~i}&o2iQQdx^xRx)K_MX~h6^uCZfwhy zo>%?OGUwJ7PFL6eKTU4@`BTFqV<E6@+cr>xyp2!x)~i;bz~7y{CCBaiM3WAdM}8>X zG}Vgl@3Aij*Kw4*)GY0BUw4&NttF$#2W{ZC<o!Ka@WAamH&EpwrWey;|NrOtj=sLO zsZ&K64xBv6+1lFr;_7N~BV%L3!bdJoo;*2lZQD!vSuHIrDk>@r4A9yoYhm2owI9D_ z=&$)*Zr2z8EAo}a7inRpIZOLKL~Pl0;7V5M;=6%`g@$wH&W*N7`}XGM4BKimh6msG z|F^x8WeV%=Ffu%O`ZV!)pDY7Ic6N3|bTl(VQ*-mfcklZ4?6J9g?k%{I*;)LY$zk*J z&J};pNEtC~ZsVC?#a5cMwvJcF#6+ZOuiVd{KUb_??Hm*&v}o~S>E6!1zGH2?(h4mO zQBhKF-n@D6<OvJIj<UB>3=1}F5QyHE<LKk#v-q*wVnOd|I);01Z&<U2hpm~Bq2S{q z*B4)_cI@1_wQ#44ONw69US_0TO!L8p$jxbvtFJ2A*~u{oNK1DYK0cQC{M_6N8781= zaaZYUpU>+4^LjpgDvF4T0u70O`}U2Y;m8q}BS()iGR&+0S9xZRr7#1i3(If!gF$O5 zmzbDXM}L2NbMs+m{W%>SPka?;RUbUX)6&wSpseitO2WlulhX3}bN2|d<5@8_^X+bd z>hJG(!^5weoz>OWcAjrnd&_)+%Z_P%$78NB7U;@eJ-lY=f&~oM*T?5Scj<4tDR$%V zgv+-*<Xf^|R-8$dWw?GRaOUNxckTO~RFIaf&H8wnxA83F^-CXToVH?Ew{)@5@~wCH z<C#HA+aB?OR<EtREOui}<mLx2UU1CyyZ)w9kwr`|hNIVwxxT(WB{kLdS&Q!X1zVRa zS#sdqIljimMh+gHBlG|Nd47D(hSaXEt`A=>`yW1joIPr-n0ENO93wZq>q}L$C!Q|d zS^L}U(4j*D0s;(c!@BwH|7<uZu`_1+v^B1J*I5~O*qGV)WI8@RK5lNQK0W(WfQx$R ziwlYqXU^2z5|z98dHI_gipTn7wR_waTU32n!NJX4eTFmoony9T#fJqI-)^SM7q7au zyXNPn6KBqdIB^6<MrzvF*tq!l`7Jgq+Nq}-*&Va4`(lPjXlSSthvKVOuUsM`W>{2x z$v~18aH{!sGhIndZQ0E|MLX5_7hisNpp93$?EZOCZ$H0dUtV5b7@#2_C)d~2#WiKx zG`nNqkT}sgd9rX*QxgM&ZPk~C{(gQhFE6RyiKk05Pewg)S{SgP`ujVEhF`y`u3W$F zyfDB-N_qP0OOv*qa4jq}EPHdK(X#lNNp+mw^(ToXx0tFYAQdY&|1mH%EOzf#Qcz$x zZ}XX_sHkX3-)9F_R#uC`M=Wf-QdeH=kXe6v?=_|q&v!}iu(7bQ9lCf?FfcH1(>^`^ zC~k%R*Po6_=RbJdZ{HSl^G|=FW$ff_NsV`#GoGEDS@d^yR6u}$UhJ-xH*a$Ge7luB z!=iA}m5J|G@7g6bS>69wm#B6^cJ}QTbMCA={r8&E^4Cd;iH(!h{TJ=nA@TC%ON-iH zC5t~7?b73q>Xun|`rF&vpvjmA4<1zaO}uQHsV#QHcE!~!5mC{Za^m1tY2+c7-1O}1 z>|a|??kId55FdX(@qKAne0=}O$?A`nPLDhE`Mmx8Cwu0sJKYSj?Nj$W+v=nrA0FP= zTV1~RanUZl{X4ElJt>Iold){dy}j+=)vKbHm-!xUWM=>H>9oGNl=AWHOHn!-syi@C z3YF3|W;=_XcD>9`e);mHfmH8{tE;CUSvgC{DdF3jn;$-YJh;fUdzLITVxI1rGIi?0 zMT?jiK$CfIZg1zmckf>Bu@g^sZAsBB+My?3`$e#;yE`x-pkc`pl^zL0ru%ijb3y4y zjo(?dlwswOBS#ceRavjEk8h7!du;x{FX~sWU&k?nbENbATgU95|9-#!cr`rUacR)P z?Ca}1yu4mb*zaXg^n?R6#P{RpPelU*0frCP_x}@ZZEa<EQW_hzNo;4#^h=k5&de}m zW@~12TPzs7-0$d*kB<fA<@p`rpqVbSPiFne!w=?KfR?NN&~{N$!n0C!;u0<{u6LE< zf8(`vbxrRUTC1wEa<wwq|M}4T^3RWtk0(x=bV+Y!Ol++DY5SPk=X1*!T+RA;&icKH z)xYoi|I2zU6|(>PW%A1ze#@O59S_#-zc<ISc*5-2)%Ok8@A<^F@BhEwYR7|vf@bYq z{PS?s)~u<K`)X%fv585XJhUnG^oLig*Uth~rhlI5*H1JE{46CU6=n17_4@s8VPRr+ zKc7ridwy%vzW+bx|8HEsK7aMIE#l1wFTD8}xjF6O=kxaUXEIl8PG}5lFZ}q(l}E-x z;CR1$eo@}%XJ;j>%htRsbZ}sJ^XAQl^7r@DlI!<2pE_~kz=wy285!#TJeCKo2knuw zm0GuM-HlDD-1lm~$Nu>FGx5!hjmzr)@G{J=`NRn-xWvTTK0iOtyrWvY{p4hIWfK#T zg$oy26g}y9`TE>kYw^YI{h%U3I#tl=!QAqDjeg6U=g*fn%e%wz;9aaxR#uil#s!7T z8{?(z>uidOiVEJ{*{SyY+nbxs9UUAjtgMNrrf9BMw~md0N8V0GQc{wk!8H5Yfy<YL zK_h^x*YE3M<rbUqZ27WfNgo~@EGRC{-Xgtv&6*=iJSRVR{kr?cjR=WeHwK5KB&9!p z{w%osazn+(q{~0v-QE4Tncwcf0>|cpFE0Z3e7{%y;p^Ad{rm0j)&0(8Xz1<bjfst& zIB(v)6j>Mkw;$`9xG&Z}dH$SxZJ74CdA8b2%*;nSh1F;E?5y~>h=YsE$oAjU&$EJp zgO6@VJlrDSWb^yY=0^_>Hal@B&a3(485kC()*|re$H&J?Mn*v#ibvYz>oh>~Ac5iG z>VETVG<TJ~)hc^;N3%sB>Cur+P@8MnPw#0upypoP->S>$x<`*3>5#Ag6Ud=>WU2S` zmhIcOzkJQc%G!DB)-AQ0`~Ux|Hq5%Baq+?hg>^AIm11IIjyyd*y;Z=e<n^_+H=n26 z-Bp_O@6XSqzrVgF{rd9Ki9>Nq)Y=p0&x@y~rY60(uyE0?T~Qp0M=mY(J_^c}iY)>Y zCr$)8chQn1E;c`(OkT8SPt298S34D*+d6)Je$K7fGQ%*rt>R&;c+0^A0ZU8ELwE1q z{W2#iI=UOw&TVUB1G#5o=4G|Opdhe&wq{><sjb~x@$2Ps-}Ax@8w(%1DYghaefsp! z<;%emy&w)~QS{pMlao{*UG}%nEnDlt-)wC#DJC*f64XGf`TchLvfFxpqcSr!e|>p* z`N@w93!NDm?En2p?%sdh_`D6Xx3~9?pFa;K7|bXut*fi+jZ5z8>@@uQ{o-PGMuy+t z-X3mZ<ra{T=y>rWBP1+r(W+Hk2ky#-USAhGyJr94c7A0wH8wFZu@|?u>nkZM&%T>z z_WY=Lyh~J+)Ths%eUD$dbZNru+1;R7^~T0V&~jkC*j-mL;;O2wl$4bpK6t?3?d|O| z(`f0%pPj<$PAfwWUB50KzCO-WmY<s&)B)Gg(Gd_9Za&=3|M<tp#}alm7Qeo{1kD~y zkE`NbxpF1LyL(q_KRr3g@a`VNcJJwWGmO*uw&mPpl9rafe7&@!1QeywThbO?d$Hxa z(~`c!=g#rz>FI&W!Yxs`lRtyn)o*WapFQ{Yzu)g4>(~F}4q5g8PkiLkpqX`lzuvCA zcI_Gi!_pv5Z*T9*w(q8Rr7pg{zOMHD?)ZP(FJ>IMxY&KR?y9`IyG&>M+Z>sDD@055 zUe#;eKX3Evv%h>@v}n<U88a%*SMvXwx8H_+zwGbZzyE5Abp5Vds~5ZLz@bBjrkL8= z%KrZT-aRNt2s9MI%+ANKqndYC^|=l?vGN2JOKolKKcDCSKhnl4{o!u;eN)T#FJD%k zFMrT0XDel%cgLZuY}-rlYI|d2#yhod*XY(&#%4V}*4xq1VNkn1bamLwV{>dOKQ+du z-`J3NV@IJf4;%BTQ>Q@9%MBYgh&``(IyHPl{(ZYCQ>QXA96Efs@W~0m%dd;JZrys~ z^yzF9bsZg^c0Sodz0&3vE?l_q=(&Qet?bXAKMkaML58*QNH#e(vpu}9(0NDQUn?Fq zW+7o=K}pG*DS4%(rg``Fbb^|{N4rFO=gr=;XOGz4%TrctjYznnSn~QB@9pjRps9$7 zbLQN6G6yt%)c!4h@Be?lHT3k%Z2x^Y%>VJ%>-EVmE-YLo>y&+WXEA#-8}Ff0Q?)nN z{48QP@O*xKUr!IukKex!U%!66<oWyGGfgsu{`~zbC?L=vsO$!6r5qFXx4HO8CM+xr zG#I)rb~hWti_6RT|NZ;EACzkU{rktk&wu>p=5*<^FR!nculxV|e)5F{juNI>A~m(O ziRtOz54;x@6m0x{ubSU#E*~=+&jC>R_~c1TzuhkmfBU~COiWCm!Mo@4s@I*ovuc%A zkG#FyzkmNadU{N3(+}LW{@gERy6R}QT9)R$-gzmBudl5=apugCNvhsU`joh@UcLI_ z-d<}dDJjtUqcgKixhJdnR=u(Gn`6<KdV1Q!+4=h_-`t;lM!q~j1tit9diClnhAEHx z>;EV-F*Ez_UfC(E?ltN0udlCvy#N0%-*cb&O#64>og_QJ4RFl8CI9W0ui2<l{@h|M Q0|Nttr>mdKI;Vst0Nf+@>Hq)$ literal 0 HcmV?d00001 diff --git a/test1.py b/test1.py new file mode 100644 index 0000000..5bc93fb --- /dev/null +++ b/test1.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Oct 23 19:43:47 2023 + +@author: oscar +""" + +import numpy as np +from collections import Counter +import read_cifar + +def distance_matrix(M1,M2): + # dists(i,j) = dist entre ième ligne de M1 et jème ligne de M1, soit la racine de sum((M1i,p - M2j,p)²)) + # qu'on peut simplifier en sum(M1i,p²) + sum(M2j,p²) - sum(2* M1j,p * M2i,p) + + l1=np.shape(M1)[0] + l2=np.shape(M2)[0] + Vect1=np.sum(M1**2,1) + Vect2=np.sum(M2**2,1) + + Mat1=np.tile(Vect1, (l2,1)) + Mat2=np.tile(Vect2, (l1,1)) + Mat3=2*np.dot(M1,M2.T) + + dists=np.sqrt(Mat1.T+Mat2-Mat3) + + return dists + +def knn_predict(dists,labels_train,k): + labels_predict=np.array([]) + size_test=np.shape(dists)[1] + for j in range(size_test): + list_arg_min=np.argsort(dists[:,j]) + labels_sorted=[labels_train[i] for i in list_arg_min] + k_labels=labels_sorted[:k] + count = Counter(k_labels) + + labels_predict=np.append(labels_predict,count.most_common(1)[0][0]) + + return labels_predict + +def evaluate_knn(data_train,data_test,labels_train,labels_test,k): + dists=distance_matrix(data_train,data_test) + labels_predict=knn_predict(dists,labels_train,k) + count=np.sum(labels_predict==labels_test) + return count/np.shape(labels_predict) + +if __name__ == "__main__": + file = "./data/cifar-10-python/" + data, labels = read_cifar.read_cifar(file) + data_train,labels_train,data_test,labels_test=read_cifar.split_dataset(data,labels,0.9) + print(evaluate_knn(data_train,data_test,labels_train,labels_test,20)) \ No newline at end of file -- GitLab