Select Git revision
main.py 1.59 KiB
import read_cifar
import knn
import matplotlib.pyplot as plt
import mlp
import time
# Experiment configuration.
split = 0.9  # passed to read_cifar.split_dataset; presumably the train fraction — TODO confirm
d_h = 64  # passed to mlp.run_mlp_training as d_h; presumably the hidden-layer width — TODO confirm
learning_rate = 0.1
num_epochs = 100
# Use forward slashes only: they work on both Windows and POSIX, whereas a
# backslash is not a path separator on POSIX and "\c" is an invalid escape.
batch_path = "data/cifar-10-python/cifar-10-batches-py"
data, labels = read_cifar.read_cifar(batch_path)
data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, split)
def run_knn_experiment(data_train, labels_train, data_test, labels_test, k_values=range(1, 21)):
    """Sweep k for the kNN classifier, then plot accuracy and runtime against k.

    This script does not call it; invoke it explicitly to regenerate
    results/knn.png and results/time_knn.png.

    Args:
        data_train, labels_train, data_test, labels_test: forwarded unchanged
            to knn.evaluate_knn.
        k_values: iterable of k values to evaluate (default 1..20).

    Returns:
        (accuracies, times): lists aligned with k_values; times are wall-clock
        seconds spent in each knn.evaluate_knn call.
    """
    import os

    k_values = list(k_values)  # iterated several times below (sweep + xticks)
    accuracies = []
    times = []
    for k in k_values:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        accuracy = knn.evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        execution_time = time.perf_counter() - start_time
        times.append(execution_time)
        accuracies.append(accuracy)
        print(f"Accuracy for k={k}: {accuracy:.2f}, Time: {execution_time:.2f}s")

    # savefig does not create missing directories.
    os.makedirs("results", exist_ok=True)
    for values, title, ylabel, path in (
        (accuracies, 'KNN Accuracy vs. k', 'Accuracy', 'results/knn.png'),
        (times, 'Execution time vs. k', 'time', 'results/time_knn.png'),
    ):
        plt.figure(figsize=(8, 6))
        plt.plot(k_values, values, marker='o')
        plt.title(title)
        plt.xlabel('k')
        plt.ylabel(ylabel)
        plt.xticks(k_values)
        plt.grid(True)
        plt.savefig(path)
        plt.show()

    return accuracies, times
# Train the MLP. train_accuracies is plotted per epoch below; test_accuracy
# is presumably a single scalar score on the held-out split — confirm in mlp.
train_accuracies, test_accuracy = mlp.run_mlp_training(
    data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epochs
)
# Report the held-out result instead of silently discarding it.
print(f"MLP test accuracy: {test_accuracy}")
def plot_learning_accuracy(train_accuracies, save_path="results/mlp.png"):
    """Plot MLP training accuracy per epoch and save the figure to disk.

    Args:
        train_accuracies: sequence of training accuracies, one per epoch;
            the x-axis numbers epochs starting from 1.
        save_path: output image path (default "results/mlp.png"). Its parent
            directory is created if it does not exist yet.

    Returns:
        The matplotlib Figure, left open so the caller can show or close it.
    """
    import os

    epochs = range(1, len(train_accuracies) + 1)
    fig = plt.figure()
    plt.plot(epochs, train_accuracies)
    plt.xlabel("Epoch")
    plt.ylabel("Training Accuracy")
    plt.title("MLP Training Accuracy")
    # savefig raises FileNotFoundError when the target directory is missing.
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig.savefig(save_path)
    return fig


plot_learning_accuracy(train_accuracies)