# Image classification

Antonin DELORME - TD 1 sur le cours Deep Learning

## Prepare the CIFAR dataset

### 2.

```python
def read_cifar_batch(path):
    """Read one pickled CIFAR-10 batch file.

    Example: path = "data/cifar-10-batches-py/data_batch_1"

    Returns (labels, data): labels[i] is the label of image i and data[i]
    holds the 3072 pixel values of image i, as float32.
    """
    # pickle can execute arbitrary code: only load trusted files
    # (e.g. the official CIFAR-10 archive).
    with open(path, "rb") as fo:
        batch = pickle.load(fo, encoding="bytes")
    # Look the entries up by key rather than by their position in the dict.
    labels = list(batch[b"labels"])
    # Convert the raw uint8 pixels to float32 so that later arithmetic
    # (squares and dot products in distance_matrix) cannot overflow.
    data = np.asarray(batch[b"data"], dtype=np.float32)
    return labels, data
```

### 3.

```python
import os


def read_cifar(path):
    """Load the full CIFAR-10 dataset: test_batch followed by data_batch_1..5.

    Example: path = "data/cifar-10-batches-py"

    Returns (labels, data) with all batches concatenated in that order.
    """
    batch_names = ["test_batch"] + [f"data_batch_{i}" for i in range(1, 6)]
    labels = []
    data_parts = []
    for name in batch_names:
        # os.path.join keeps the code portable (no hard-coded "\\"), and
        # each batch file is unpickled only once.
        batch_labels, batch_data = read_cifar_batch(os.path.join(path, name))
        labels += batch_labels
        data_parts.append(batch_data)
    # A single concatenation instead of re-copying a growing array each time.
    data = np.concatenate(data_parts, axis=0)
    return labels, data
```

### 4.

```python
def split_dataset(labels, data, split):
    """Randomly split the dataset into a training set and a test set.

    split is the fraction of samples that goes to the TEST set
    (e.g. split=0.1 keeps 10 % of the images for testing).

    Returns (data_train, labels_train, data_test, labels_test). Training
    samples keep their original order; test samples are in random order.

    Raises ValueError if split is not between 0 and 1.
    """
    if not 0 <= split <= 1:
        raise ValueError("split must be between 0 and 1")
    n_test = round(split * len(labels))
    # A random permutation yields distinct indices in O(n); drawing random
    # indices and checking `not in` on a list is O(n^2) on 60 000 images.
    perm = np.random.permutation(len(labels))
    test = perm[:n_test]
    train = np.sort(perm[n_test:])
    data_train = data[train]
    data_test = data[test]
    labels_train = [labels[i] for i in train]
    labels_test = [labels[i] for i in test]
    return data_train, labels_train, data_test, labels_test
```

## KNN

### 1.

```python
def distance_matrix(a, b):
    """Return the Euclidean distances between every row of a and every row of b.

    Entry (i, j) of the result is the distance between a[i] and b[j]. It is
    computed without Python loops using ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Integer inputs (e.g. uint8 pixels) would overflow in np.square / np.dot.
    if not np.issubdtype(a.dtype, np.floating):
        a = a.astype(np.float64)
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype(np.float64)
    sq_a = np.sum(np.square(a), axis=1, keepdims=True)  # column of ||a[i]||^2
    sq_b = np.sum(np.square(b), axis=1, keepdims=True)  # column of ||b[j]||^2
    sq_dists = sq_a + sq_b.T - 2 * np.dot(a, b.T)
    # Rounding errors can make tiny squared distances slightly negative,
    # which np.sqrt would turn into NaN.
    return np.sqrt(np.maximum(sq_dists, 0))
```

### 2.

```python
def knn_predict(dists, labels_train, k):
    """Predict each test sample's label by majority vote of its k nearest neighbours.

    dists[i][j] is the distance between test sample i and training sample j.
    Labels must be non-negative integers (required by np.bincount). Equal
    distances are resolved in favour of the lower training index, and tied
    votes in favour of the smallest label.

    Raises ValueError unless 1 <= k <= len(labels_train).
    """
    if not 1 <= k <= len(labels_train):
        raise ValueError("k must be between 1 and the number of training samples")
    labels_train = np.asarray(labels_train)
    predict = []
    for row in dists:
        # Stable sort: among equal distances the lower training index comes first.
        nearest = np.argsort(row, kind="stable")[:k]
        votes = np.bincount(labels_train[nearest])
        # argmax returns the first (i.e. smallest) label when votes are tied.
        predict.append(int(np.argmax(votes)))
    return predict
```

### 3.
```python
def evaluate_knn(data_train, data_test, labels_train, labels_test, k, dist):
    """Print and return the accuracy of the k-NN classifier on the test set.

    dist is the precomputed distance matrix: rows are test samples, columns
    are training samples (it is passed to knn_predict with labels_train, and
    row i's prediction is compared with labels_test[i]).
    data_train and data_test are not used by the computation itself.
    """
    pred = knn_predict(dist, labels_train, k)
    correct = sum(truth == guess for truth, guess in zip(labels_test, pred))
    accuracy = correct / len(labels_test)
    print("Accuracy :", accuracy)
    return accuracy
```

### 4.