Skip to content
Snippets Groups Projects
Commit 00589b68 authored by MSI\alber's avatar MSI\alber
Browse files

Read cifar finished and distance matrix computed

parent 06fb753f
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
def distance_matrix(m1, m2): def compute_distance(m1, m2):
if m1.shape != m2.shape: if m1.shape != m2.shape:
raise ValueError("Dimensions must be identical") raise ValueError("Dimensions must be identical")
...@@ -10,7 +10,22 @@ def distance_matrix(m1, m2): ...@@ -10,7 +10,22 @@ def distance_matrix(m1, m2):
return dist return dist
def knn_predict(dist, labels_train, k, ): def distance_matrix(data_train, data_test):
dists = []
for test in data_test:
dist = []
for train in data_train:
dist.append(compute_distance(test, train))
dists.append(dist)
return dists
return dists
def knn_predict(dist, labels_train, k):
return return
def evaluate_knn(data_train , labels_train,data_test ,labels_test, k): def evaluate_knn(data_train , labels_train,data_test ,labels_test, k):
return return
\ No newline at end of file
import pickle import pickle
import numpy as np import numpy as np
import os
from sklearn.model_selection import train_test_split
def read_cifar_batch(batch): def read_cifar_batch(batch):
with open(batch, 'rb') as fo: with open(batch, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes') dict = pickle.load(fo, encoding='bytes')
data = dict[b'data'] data = dict[b'data']
labels = dict[b'labels'] labels = dict[b'labels']
return data.astype(np.float32), np.array(labels, dtype=np.int64) print(dict[b'batch_label'])
return data, labels
def read_cifar(path):
batches_list = os.listdir(path)
data, labels = [], []
for batch in batches_list:
if(batch == 'batches.meta' or batch == 'readme.html'):
continue
data_batch, labels_batch = read_cifar_batch(path + '/' + batch)
data.append(data_batch)
labels.append(labels_batch)
return np.array(data, dtype=np.float32).reshape((60000,3072)), np.array(labels, dtype=np.int64).reshape(-1)
def split_dataset(data, labels, split):
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=1-split, shuffle=True)
return data_train, data_test, labels_train, labels_test
def main():
folder_path = 'data/cifar-10-batches-py'
data, labels = read_cifar(folder_path)
print((data.shape))
print((labels.shape))
data_train, data_test, labels_train, labels_test = split_dataset(data, labels, 0.9)
print("Training set shape:", data_train.shape, labels_train.shape)
print("Testing set shape:", data_test.shape, labels_test.shape)
if __name__ == "__main__":
main()
batch='data/cifar-10-batches-py/data_batch_1'
data, labels = read_cifar_batch(batch)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment