import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from read_cifar import read_cifar_batch
import matplotlib.pyplot as plt

# Load one CIFAR-10 batch from the local data files
X,y = read_cifar_batch("data/cifar-10-batches-py/data_batch_1")

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)

# Data preprocessing
# The images need to be flattened and normalized before training; a minimal sketch follows.
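# A minimal preprocessing sketch (assumption: read_cifar_batch returns the images
# as flattened rows of uint8 pixels in [0, 255] and the labels as integer class
# indices). Scaling the pixels to [0, 1] keeps the sigmoid inputs well behaved,
# and the labels are cast to an integer array so they can index the softmax output.
X_train = X_train.astype(np.float64) / 255.0
X_test = X_test.astype(np.float64) / 255.0
y_train = np.asarray(y_train, dtype=int)
y_test = np.asarray(y_test, dtype=int)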

# Define the neural network architecture
input_size = 32 * 32 * 3  # 32x32 pixels with 3 (RGB) channels
hidden_size = 64  # Number of units in the hidden layer
output_size = 10  # 10 classes in CIFAR-10

# Initialize the weights and biases (the 1/sqrt(fan_in) scaling is an adjustment
# that keeps the sigmoid activations out of saturation at the start of training)
np.random.seed(0)
weights_input_hidden = np.random.randn(input_size, hidden_size) / np.sqrt(input_size)
bias_input_hidden = np.zeros((1, hidden_size))
weights_hidden_output = np.random.randn(hidden_size, output_size) / np.sqrt(hidden_size)
bias_hidden_output = np.zeros((1, output_size))

# Hyperparameters
learning_rate = 0.1
num_epochs = 100

# Histories for the plots at the end of the script
epoch_history, loss_history, accuracy_history = [], [], []

# Train the model (full-batch gradient descent)
for epoch in range(num_epochs):
    # Forward pass
    hidden_input = np.dot(X_train, weights_input_hidden) + bias_input_hidden
    hidden_output = 1 / (1 + np.exp(-hidden_input))  # Sigmoid activation
    output_layer = np.dot(hidden_output, weights_hidden_output) + bias_hidden_output
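    # Shapes: hidden_input and hidden_output are (number of examples, hidden_size);
    # output_layer holds the raw class scores (logits) of shape (number of examples, output_size).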

    # Softmax (the scores are shifted by their row maximum for numerical stability)
    exp_scores = np.exp(output_layer - np.max(output_layer, axis=1, keepdims=True))
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

    # Cross-entropy loss
    num_examples = len(X_train)
    correct_logprobs = -np.log(probs[range(num_examples), y_train])
    data_loss = np.sum(correct_logprobs) / num_examples
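    # For softmax probabilities p = softmax(z) trained with cross-entropy on the
    # true class y, the gradient with respect to the logits simplifies to
    # dL/dz = p - onehot(y); averaged over the batch, that is exactly the dprobs
    # computed below.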

    # Backward pass: gradient with respect to the output logits
    dprobs = probs.copy()
    dprobs[range(num_examples), y_train] -= 1
    dprobs /= num_examples

    dweights_hidden_output = np.dot(hidden_output.T, dprobs)
    dbias_hidden_output = np.sum(dprobs, axis=0, keepdims=True)

    dhidden = np.dot(dprobs, weights_hidden_output.T)
    dhidden_hidden = dhidden * hidden_output * (1 - hidden_output)  # sigmoid derivative
    dweights_input_hidden = np.dot(X_train.T, dhidden_hidden)
    dbias_input_hidden = np.sum(dhidden_hidden, axis=0, keepdims=True)

    # Update the weights and biases
    weights_input_hidden -= learning_rate * dweights_input_hidden
    bias_input_hidden -= learning_rate * dbias_input_hidden
    weights_hidden_output -= learning_rate * dweights_hidden_output
    bias_hidden_output -= learning_rate * dbias_hidden_output

    epoch_history.append(epoch)
    loss_history.append(data_loss)
    predicted_class = np.argmax(output_layer, axis=1)
    accuracy_history.append(accuracy_score(y_train, predicted_class))
    # Print the loss periodically to monitor training
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}: Loss = {data_loss:.4f}')

# Evaluate the model on the held-out test set
hidden_input = np.dot(X_test, weights_input_hidden) + bias_input_hidden
hidden_output = 1 / (1 + np.exp(-hidden_input))
output_layer = np.dot(hidden_output, weights_hidden_output) + bias_hidden_output
predicted_class = np.argmax(output_layer, axis=1)
accuracy = accuracy_score(y_test, predicted_class)
print(f'Test set accuracy: {accuracy:.4f}')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(epoch_history, loss_history, label='training loss')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.set_title('Training loss per epoch')
ax1.legend()

ax2.plot(epoch_history, accuracy_history, label='training accuracy')
ax2.set_xlabel('epoch')
ax2.set_ylabel('accuracy')
ax2.set_title('Training accuracy per epoch')
ax2.legend()
plt.tight_layout()
plt.show()