Skip to content
Snippets Groups Projects
Commit 81097624 authored by MSI\alber's avatar MSI\alber
Browse files

25/10: It seems to be working, but the accuracy is only 0.35

parent 00589b68
No related branches found
No related tags found
No related merge requests found
import numpy as np
from read_cifar import *
from collections import Counter
def compute_distance(m1, m2):
    """Return the Euclidean (L2) distance between two equally shaped arrays.

    Integer and boolean inputs (for example ``uint8`` pixel values) are
    promoted to ``float64`` before the subtraction, because NumPy integer
    arithmetic wraps around silently and would yield a wrong distance.
    Inputs that are already floating point keep their dtype.

    Args:
        m1: Array or array-like of numbers.
        m2: Array or array-like of numbers with the same shape as ``m1``.

    Returns:
        The square root of the sum of squared element-wise differences.

    Raises:
        ValueError: If the two inputs do not have identical shapes.
    """
    a = np.asarray(m1)
    b = np.asarray(m2)
    if a.shape != b.shape:
        raise ValueError("Dimensions must be identical")
    # Only exact (integer / bool) dtypes need promoting; floats are left alone
    # so the result dtype for float inputs is unchanged.
    if not np.issubdtype(a.dtype, np.inexact):
        a = a.astype(np.float64)
    if not np.issubdtype(b.dtype, np.inexact):
        b = b.astype(np.float64)
    diff = a - b
    return np.sqrt(np.sum(diff ** 2))
def distance_matrix(data_train, data_test):
dists = []
......@@ -21,11 +21,48 @@ def distance_matrix(data_train, data_test):
return dists
def knn_predict(dists, labels_train, k):
    """Predict one label per row of ``dists`` by majority vote of the k nearest.

    Each element of ``dists`` is treated as the sequence of distances from one
    query sample to the training samples; position ``i`` in that sequence is
    used to index ``labels_train``.

    Args:
        dists: Iterable of 1-D distance sequences, one per query sample.
        labels_train: Labels of the training samples (array or array-like).
        k: Number of nearest neighbours that vote.

    Returns:
        A list with one predicted label per row of ``dists``. When several
        labels receive the same number of votes, the label of the closest
        neighbour among the tied ones is returned.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than the number of
            training labels.
    """
    labels_train = np.asarray(labels_train)
    n_train = labels_train.shape[0]
    if not 1 <= k <= n_train:
        raise ValueError(
            f"k must be between 1 and the number of training labels "
            f"({n_train}), got {k}"
        )

    predictions = []
    for distances in dists:
        distances = np.asarray(distances)
        n_candidates = distances.shape[0]
        if k < n_candidates:
            nearest = np.argpartition(distances, k)[:k]
        else:
            # argpartition rejects kth == len(distances); every sample is a
            # neighbour in that case, so no partitioning is needed.
            nearest = np.arange(n_candidates)
        # argpartition returns the k smallest in arbitrary order. Sorting them
        # by distance makes the vote deterministic: Counter keeps insertion
        # order, so a tie goes to the label seen first, i.e. the closest one.
        nearest = nearest[np.argsort(distances[nearest], kind="stable")]
        votes = Counter(labels_train[nearest])
        predictions.append(votes.most_common(1)[0][0])
    return predictions
return dists
def evaluate_knn(predictions, labels_test):
    """Return the fraction of predictions that equal the reference labels.

    Args:
        predictions: Sequence of predicted labels.
        labels_test: Sequence of reference labels, same length as
            ``predictions``.

    Returns:
        Accuracy as a float between 0.0 and 1.0.

    Raises:
        ValueError: If ``predictions`` is empty or the two sequences differ
            in length.
    """
    total = len(predictions)
    if total == 0:
        raise ValueError("Cannot compute accuracy of an empty prediction list")
    if total != len(labels_test):
        raise ValueError(
            f"Got {total} predictions but {len(labels_test)} labels"
        )
    # Count matches pairwise; avoids shadowing the builtin ``sum`` with a local.
    correct = sum(
        1 for predicted, expected in zip(predictions, labels_test)
        if predicted == expected
    )
    return correct / total
'''def evaluate_knn(data_train , labels_train,data_test ,labels_test, k):
return'''
def main():
    """Run the k-NN pipeline end to end and print the accuracy.

    Reads the dataset from ``data/cifar-10-batches-py``, prints the array
    shapes, splits the data with a ratio of 0.9, builds the distance matrix
    between the training and test parts, predicts with ``k = 4`` and prints
    the resulting accuracy.
    """
    cifar_dir = 'data/cifar-10-batches-py'
    data, labels = read_cifar(cifar_dir)
    print(data.shape)
    print(labels.shape)

    parts = split_dataset(data, labels, 0.9)
    data_train, data_test, labels_train, labels_test = parts
    print("Training set shape:", data_train.shape, labels_train.shape)
    print("Testing set shape:", data_test.shape, labels_test.shape)

    distance_rows = distance_matrix(data_train, data_test)
    predicted_labels = knn_predict(distance_rows, labels_train, 4)
    print(evaluate_knn(predicted_labels, labels_test))
def knn_predict(dist, labels_train, k):
    """Placeholder that ignores its arguments and returns ``None``.

    NOTE(review): this looks like the removed side of the diff view; another
    ``knn_predict`` appears earlier in the same listing. Confirm this stub is
    not part of the committed file.
    """
    return None
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
def evaluate_knn(data_train, labels_train, data_test, labels_test, k):
    """Placeholder that ignores its arguments and returns ``None``.

    NOTE(review): this looks like the removed side of the diff view; a
    two-argument ``evaluate_knn`` appears earlier in the same listing. Confirm
    this stub is not part of the committed file.
    """
    return None
\ No newline at end of file
......@@ -3,7 +3,6 @@ import numpy as np
import os
from sklearn.model_selection import train_test_split
def read_cifar_batch(batch):
with open(batch, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
......@@ -21,30 +20,11 @@ def read_cifar(path):
data_batch, labels_batch = read_cifar_batch(path + '/' + batch)
data.append(data_batch)
labels.append(labels_batch)
return np.array(data, dtype=np.float32).reshape((60000,3072)), np.array(labels, dtype=np.int64).reshape(-1)
return np.array(data, dtype=np.float32).reshape((60000, 3072)), np.array(labels, dtype=np.int64).reshape(-1)
def split_dataset(data, labels, split, random_state=None):
    """Shuffle ``data`` and ``labels`` together and split them in two parts.

    Args:
        data: Samples to split.
        labels: Labels aligned with ``data``.
        split: Fraction of the samples assigned to the training part; the
            remaining ``1 - split`` is passed to ``train_test_split`` as
            ``test_size``.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            shuffle can be reproduced between runs. The default ``None``
            keeps the previous unseeded behaviour.

    Returns:
        The tuple ``(data_train, data_test, labels_train, labels_test)``.
    """
    data_train, data_test, labels_train, labels_test = train_test_split(
        data,
        labels,
        test_size=1 - split,
        shuffle=True,
        random_state=random_state,
    )
    return data_train, data_test, labels_train, labels_test
def main():
    """Load the dataset, split it with a ratio of 0.9 and print the shapes.

    Reads from ``data/cifar-10-batches-py``, prints the shape of the full
    data and label arrays, then the shapes of the training and test parts.
    """
    cifar_dir = 'data/cifar-10-batches-py'
    data, labels = read_cifar(cifar_dir)
    print(data.shape)
    print(labels.shape)

    parts = split_dataset(data, labels, 0.9)
    data_train, data_test, labels_train, labels_test = parts
    print("Training set shape:", data_train.shape, labels_train.shape)
    print("Testing set shape:", data_test.shape, labels_test.shape)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment