diff --git a/main.py b/main.py index 8eeff9748d0fd516160a59126f51497c7a9fe877..9d4bb938cd242ed1ceb5936adb3debee1a99ffd8 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ batch_path = "data/cifar-10-python\cifar-10-batches-py" data, labels = read_cifar.read_cifar(batch_path) data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, split) -"""k_values = range(1, 21) +k_values = range(1, 21) accuracies = [] times = [] @@ -45,7 +45,7 @@ plt.ylabel('time') plt.xticks(k_values) plt.grid(True) plt.savefig('results/time_knn.png') -plt.show()""" +plt.show() train_accuracies,test_accuracy = mlp.run_mlp_training(data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epochs)