diff --git a/README.md b/README.md
index 144715d30927deead5cf5aff9780b99bae70249f..259aac4b5531b8ab7336e4ce6702e2d0009281ea 100644
--- a/README.md
+++ b/README.md
@@ -58,4 +58,178 @@ classes = [
     "ship",
     "truck",
 ]
+
+
+# Sanity check on a single batch (assumes device, model_cifar, and criterion
+# are already initialized, as shown in the cells below)
+for batch in train_loader_cifar:
+    x, y = batch
+    print(x.shape)
+    print(y.shape)
+
+    x, y = x.to(device), y.to(device)
+    y_hat = model_cifar(x)
+    # Calculate the batch loss (CrossEntropyLoss already averages over the batch)
+    loss = criterion(y_hat, y)
+    print(loss)
+    _, predicted_label = torch.max(y_hat, 1)
+    print(predicted_label)
+    print(y)
+    print(model_cifar)
+    break
+```
+
+Then, to run the ViT model on the CIFAR dataset, we need to initialize the device, the model, and the training hyperparameters:
+
+```
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(
+    "Using device: ",
+    device,
+    f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
+)
+
+# Input size: 3 x 32 x 32; 32 is divisible by n_patches=8, and we stack 4 blocks
+model_cifar = ViT(
+    (3, 32, 32), n_patches=8, n_blocks=4, hidden_d=16, n_heads=2, out_d=10
+).to(device)
+
+N_EPOCHS = 5
+LR = 0.005
+```
+
+Then train and validate the model:
+
+```
+import numpy as np
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.optim import Adam
+from tqdm import tqdm
+
+optimizer = Adam(model_cifar.parameters(), lr=LR)
+criterion = CrossEntropyLoss()
+val_accuracies = []
+train_loss_list_cifar = []  # per-epoch training losses, kept for visualization
+valid_loss_min_cifar = np.inf  # track change in validation loss
+
+
+for epoch in range(N_EPOCHS):
+
+    # Put the model in training mode
+    model_cifar.train()
+    train_loss = 0.0
+    for batch in tqdm(train_loader_cifar):
+        # Forward pass: compute predicted outputs by passing inputs to the model
+        x, y = batch
+        x, y = x.to(device), y.to(device)
+        y_hat = model_cifar(x)
+        # Calculate the batch loss
+        loss = criterion(y_hat, y)
+        # Clear the gradients of all optimized variables before the backward pass
+        optimizer.zero_grad()
+        # Backward pass: compute the gradient of the loss with respect to the model parameters
+        loss.backward()
+        # Perform a single optimization step (parameter update)
+        optimizer.step()
+        # Update the running training loss
+        train_loss += loss.detach().cpu().item() / len(train_loader_cifar)
+    train_loss_list_cifar.append(train_loss)
+    print(f"Epoch {epoch + 1}/{N_EPOCHS} train loss: {train_loss:.2f}")
+
+    # At the end of each training epoch, evaluate the model on the validation set
+    model_cifar.eval()
+    with torch.no_grad():  # inference only: no gradients needed
+        correct, total = 0, 0
+        valid_loss = 0.0
+        for batch in tqdm(valid_loader_cifar):
+            x, y = batch
+            x, y = x.to(device), y.to(device)
+
+            # Forward pass: compute predicted outputs by passing inputs to the model
+            y_hat = model_cifar(x)
+
+            # Calculate the batch loss
+            loss = criterion(y_hat, y)
+            # Update the average validation loss
+            valid_loss += loss.item() / len(valid_loader_cifar)
+
+            # Compute accuracy
+            _, predicted_label = torch.max(y_hat, 1)
+            total += y.size(0)  # number of samples processed so far
+            correct += (predicted_label == y).sum().item()
+        val_accuracies.append(correct / total)
+        print(f"Validation loss: {valid_loss:.2f}")
+        print(f"Validation accuracy: {correct / total * 100:.2f}%")
+
+    # Save the model if the validation loss has decreased
+    if valid_loss <= valid_loss_min_cifar:
+        print(
+            "Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(
+                valid_loss_min_cifar, valid_loss
+            )
+        )
+        torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt")
+        valid_loss_min_cifar = valid_loss
+```
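+
+Since the training loop saves the best weights to `model_cifar_transformers.pt`, you can reload that checkpoint before testing, so that evaluation uses the model with the lowest validation loss rather than the weights from the last epoch. A minimal sketch, assuming the file saved by the loop above:
+
+```
+# Restore the best checkpoint saved during training
+model_cifar.load_state_dict(
+    torch.load("model_cifar_transformers.pt", map_location=device)
+)
+```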
Saving model ...".format( + valid_loss_min_cifar, valid_loss + ) + ) + torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt") + valid_loss_min_cifar = valid_loss + +``` + +After all, you have to train the model : +``` +# track test loss +test_loss = 0.0 +class_correct = list(0.0 for i in range(10)) +class_total = list(0.0 for i in range(10)) + + +model_cifar.eval() +with torch.no_grad(): + correct, total = 0, 0 + test_loss = 0.0 + + for batch in tqdm(test_loader_cifar): + x, y = batch + x, y = x.to(device), y.to(device) + + # Forward pass + y_hat = model_cifar(x) + + # Compute the loss + loss = criterion(y_hat, y) + test_loss += loss.item() + + # Compute accuracy + _, predicted = torch.max(y_hat, 1) + total += y.size(0) + correct += (predicted == y).sum().item() + + for i in range(len(y)): + label = y.data[i] + class_correct[label] += (predicted[i] == label).item() + class_total[label] += 1 + + print(f"Test loss: {test_loss / len(test_loader_cifar):.2f}") + print(f"Test accuracy: {correct / total * 100:.2f}%") + +for i in range(10): + if class_total[i] > 0: + print( + "Test Accuracy of %5s: %2d%% (%2d/%2d)" + % ( + classes[i], + 100 * class_correct[i] / class_total[i], + np.sum(class_correct[i]), + np.sum(class_total[i]), + ) + ) + else: + print("Test Accuracy of %5s: N/A (no training examples)" % (classes[i])) + +print( + "\nTest Accuracy (Overall): %2d%% (%2d/%2d)" + % ( + 100.0 * np.sum(class_correct) / np.sum(class_total), + np.sum(class_correct), + np.sum(class_total), + ) +) ``` \ No newline at end of file