Skip to content
Snippets Groups Projects
Select Git revision
  • 00beb9f4639487b70b8c765166b4c3d537be18e4
  • main default protected
  • master
3 results

main.py

Blame
  • main.py 1.59 KiB
    import os
    import time

    import matplotlib.pyplot as plt

    import knn
    import mlp
    import read_cifar
    
    # --- Experiment configuration ------------------------------------------
    # Passed to read_cifar.split_dataset; presumably the fraction of samples
    # kept for training (TODO confirm against read_cifar.split_dataset).
    split = 0.9
    # MLP hyperparameters forwarded to mlp.run_mlp_training; d_h is presumably
    # the hidden-layer width (TODO confirm in mlp.py).
    d_h = 64
    learning_rate = 0.1
    num_epochs = 100

    # --- Data loading --------------------------------------------------------
    # Use forward slashes only: they are accepted on both POSIX and Windows.
    # A backslash here ("\c") is an invalid escape sequence (SyntaxWarning on
    # Python 3.12+) and becomes a literal "\" that is not a path separator on
    # Linux/macOS, so the directory would not be found there.
    batch_path = "data/cifar-10-python/cifar-10-batches-py"
    data, labels = read_cifar.read_cifar(batch_path)
    data_train, labels_train, data_test, labels_test = read_cifar.split_dataset(data, labels, split)
    
    def _plot_against_k(k_values, values, title, ylabel, output_path, show):
        """Plot *values* against *k_values* with point markers and save to *output_path*."""
        fig = plt.figure(figsize=(8, 6))
        plt.plot(k_values, values, marker='o')
        plt.title(title)
        plt.xlabel('k')
        plt.ylabel(ylabel)
        plt.xticks(k_values)
        plt.grid(True)
        fig.savefig(output_path)
        if show:
            plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)


    def run_knn_experiment(data_train, labels_train, data_test, labels_test,
                           k_values=range(1, 21), results_dir="results", show=True):
        """Evaluate KNN for each k, print accuracy/time, and plot both against k.

        This experiment is not run by the script by default; call
        ``run_knn_experiment(data_train, labels_train, data_test, labels_test)``
        to run it.

        Args:
            data_train, labels_train, data_test, labels_test: Dataset splits,
                forwarded unchanged to ``knn.evaluate_knn``.
            k_values: Iterable of k values to evaluate (materialized once, since
                it is used both for evaluation and as plot ticks).
            results_dir: Directory where ``knn.png`` and ``time_knn.png`` are
                written; created if missing.
            show: Whether to call ``plt.show()`` after saving each figure.

        Returns:
            Tuple ``(accuracies, times)`` with one entry per k, in order.
        """
        k_values = list(k_values)
        os.makedirs(results_dir, exist_ok=True)

        accuracies = []
        times = []
        for k in k_values:
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            accuracy = knn.evaluate_knn(data_train, labels_train, data_test, labels_test, k)
            execution_time = time.perf_counter() - start_time
            times.append(execution_time)
            accuracies.append(accuracy)
            print(f"Accuracy for k={k}: {accuracy:.2f}, Time: {execution_time:.2f}s")

        _plot_against_k(k_values, accuracies, 'KNN Accuracy vs. k', 'Accuracy',
                        os.path.join(results_dir, 'knn.png'), show)
        _plot_against_k(k_values, times, 'Execution time vs. k', 'time',
                        os.path.join(results_dir, 'time_knn.png'), show)
        return accuracies, times
    
    
    def plot_learning_accuracy(train_accuracies, output_path="results/mlp.png"):
        """Plot training accuracy per epoch and save the figure.

        Args:
            train_accuracies: Sized sequence of training accuracies, one per
                epoch, plotted against epochs 1..len(train_accuracies).
            output_path: File the figure is saved to. Its parent directory is
                created if it does not exist (savefig would otherwise fail).
        """
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig, ax = plt.subplots()
        epochs = range(1, len(train_accuracies) + 1)
        ax.plot(epochs, train_accuracies)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Training Accuracy")
        ax.set_title("MLP Training Accuracy")
        fig.savefig(output_path)
        # Close explicitly: nothing later in the script shows or reuses it.
        plt.close(fig)


    # Train the MLP; the call returns the per-epoch training accuracies and a
    # test-set accuracy (shape/type of the latter not visible here).
    train_accuracies, test_accuracy = mlp.run_mlp_training(
        data_train, labels_train, data_test, labels_test, d_h, learning_rate, num_epochs
    )
    # Report the test result instead of silently discarding it.
    print("MLP test accuracy:", test_accuracy)

    plot_learning_accuracy(train_accuracies)