diff --git a/README.md b/README.md
index 144715d30927deead5cf5aff9780b99bae70249f..259aac4b5531b8ab7336e4ce6702e2d0009281ea 100644
--- a/README.md
+++ b/README.md
@@ -58,4 +58,178 @@ classes = [
     "ship",
     "truck",
 ]
+
+
+for batch in train_loader_cifar:
+    x, y = batch
+    print(x.shape)  # (batch_size, 3, 32, 32)
+    print(y.shape)  # (batch_size,)
+
+    x, y = x.to(device), y.to(device)
+    y_hat = model_cifar(x)
+    # Calculate the batch loss (CrossEntropyLoss already averages over the batch)
+    loss = criterion(y_hat, y)
+    print(loss.item())
+    max_value, predicted_label = torch.max(y_hat, 1)
+    print(predicted_label)
+    print(y)
+    print(model_cifar)
+    break
+```
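+
+If you are following along standalone, a loader like `train_loader_cifar` is typically built from `torchvision.datasets.CIFAR10`. Here is a sketch of the standard setup (an assumption: the actual definitions appear earlier in this README and may differ):
+
+```
+import torch
+from torchvision import datasets, transforms
+
+# Assumed standard CIFAR-10 training loader; may differ from the README's own setup
+transform = transforms.ToTensor()
+train_set = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
+train_loader_cifar = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
+```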
+
+Then, to run the ViT model on the CIFAR dataset, we need to initialise it (note that the inspection snippet above already uses `device`, `model_cifar`, and `criterion`, which are defined here and in the training section below):
+
+```
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(
+    "Using device: ",
+    device,
+    f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
+)
+
+model_cifar = ViT(
+    (3, 32, 32), n_patches=8, n_blocks=4, hidden_d=16, n_heads=2, out_d=10
+).to(device)  # input: 3 x 32 x 32 images, 4 blocks; 8 patches per side (32 is divisible by 8)
+
+N_EPOCHS = 5
+LR = 0.005
+```
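+
+For intuition, here is a minimal standalone sketch of the patch split these hyperparameters imply, assuming the usual flatten-the-patches ViT design (the actual `ViT` internals may differ):
+
+```
+import torch
+
+# With n_patches=8, each 32x32 image yields 8*8 = 64 patches of 4x4 pixels
+x = torch.randn(128, 3, 32, 32)  # a dummy batch (N, C, H, W)
+n_patches = 8
+patch = 32 // n_patches  # 4 pixels per patch side
+# Unfold height and width, then flatten each 3x4x4 patch into a vector
+patches = x.unfold(2, patch, patch).unfold(3, patch, patch)  # (N, C, 8, 8, 4, 4)
+patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(128, n_patches**2, -1)
+print(patches.shape)  # torch.Size([128, 64, 48]): 64 tokens of dimension 3*4*4
+```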
+
+Then train and validate the model:
+
+```
+import numpy as np
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.optim import Adam
+from tqdm import tqdm
+
+optimizer = Adam(model_cifar.parameters(), lr=LR)
+criterion = CrossEntropyLoss()
+val_accuracies = []
+train_loss_list_cifar = []  # store the per-epoch training loss for visualisation
+valid_loss_min_cifar = np.inf  # track the best (lowest) validation loss
+
+
+for epoch in range(N_EPOCHS):
+
+    # Put the model in training mode
+    model_cifar.train()
+    train_loss = 0.0
+    for batch in tqdm(train_loader_cifar):
+        # Forward pass: compute predicted outputs by passing inputs to the model
+        x, y = batch
+        x, y = x.to(device), y.to(device)
+        y_hat = model_cifar(x)
+        # Calculate the batch loss
+        loss = criterion(y_hat, y)
+        # Clear the gradients of all optimized variables, before the backward pass
+        optimizer.zero_grad()
+        # Backward pass: compute gradient of the loss with respect to model parameters
+        loss.backward()
+        # Perform a single optimization step (parameter update)
+        optimizer.step()
+        # Update the running (per-epoch average) training loss
+        train_loss += loss.item() / len(train_loader_cifar)
+    train_loss_list_cifar.append(train_loss)
+    print(f"Epoch {epoch + 1}/{N_EPOCHS} train loss: {train_loss:.2f}")
+
+    # At the end of each training epoch, evaluate the model on the validation set
+    model_cifar.eval()
+    with torch.no_grad():  # inference only: no gradients needed
+        correct, total = 0, 0
+        valid_loss = 0.0
+        for batch in tqdm(valid_loader_cifar):
+            x, y = batch
+            x, y = x.to(device), y.to(device)
+
+            # Forward pass: compute predicted outputs by passing inputs to the model
+            y_hat = model_cifar(x)
+
+            # Calculate the batch loss
+            loss = criterion(y_hat, y)
+            # Update average validation loss
+            valid_loss += loss.item() / len(valid_loader_cifar)
+
+            # Compute accuracy
+            max_value, predicted_label = torch.max(y_hat, 1)
+            total += y.size(0)  # number of samples processed so far
+            correct += (predicted_label == y).sum().item()
+        val_accuracies.append(correct / total)
+        print(f"Validation loss: {valid_loss:.2f}")
+        print(f"Validation accuracy: {correct / total * 100:.2f}%")
+
+
+    # Save the model if the validation loss has decreased
+    if valid_loss <= valid_loss_min_cifar:
+        print(
+            "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
+                valid_loss_min_cifar, valid_loss
+            )
+        )
+        torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt")
+        valid_loss_min_cifar = valid_loss
+
+```
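+
+The curves collected above (`train_loss_list_cifar` and `val_accuracies`) can then be plotted, for example with matplotlib (a minimal sketch, assuming matplotlib is installed):
+
+```
+import matplotlib.pyplot as plt
+
+# Plot the per-epoch training loss and validation accuracy side by side
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
+ax1.plot(train_loss_list_cifar)
+ax1.set(title="Training loss", xlabel="epoch", ylabel="loss")
+ax2.plot(val_accuracies)
+ax2.set(title="Validation accuracy", xlabel="epoch", ylabel="accuracy")
+plt.tight_layout()
+plt.show()
+```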
+
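+Since the loop above saves the best weights to `model_cifar_transformers.pt`, you can reload that checkpoint before testing (a short sketch using the standard PyTorch `state_dict` API):
+
+```
+# Restore the weights with the lowest validation loss
+state_dict = torch.load("model_cifar_transformers.pt", map_location=device)
+model_cifar.load_state_dict(state_dict)
+```
+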
+Finally, evaluate the trained model on the test set:
+```
+# Track the test loss and per-class accuracy
+test_loss = 0.0
+class_correct = [0.0] * 10
+class_total = [0.0] * 10
+
+model_cifar.eval()
+with torch.no_grad():
+    correct, total = 0, 0
+
+    for batch in tqdm(test_loader_cifar):
+        x, y = batch
+        x, y = x.to(device), y.to(device)
+
+        # Forward pass
+        y_hat = model_cifar(x)
+
+        # Compute the loss
+        loss = criterion(y_hat, y)
+        test_loss += loss.item()
+
+        # Compute accuracy
+        _, predicted = torch.max(y_hat, 1)
+        total += y.size(0)
+        correct += (predicted == y).sum().item()
+
+        for i in range(len(y)):
+            label = y[i].item()
+            class_correct[label] += (predicted[i] == label).item()
+            class_total[label] += 1
+
+    print(f"Test loss: {test_loss / len(test_loader_cifar):.2f}")
+    print(f"Test accuracy: {correct / total * 100:.2f}%")
+
+for i in range(10):
+    if class_total[i] > 0:
+        print(
+            f"Test accuracy of {classes[i]:>5s}: "
+            f"{100 * class_correct[i] / class_total[i]:2.0f}% "
+            f"({int(class_correct[i])}/{int(class_total[i])})"
+        )
+    else:
+        print(f"Test accuracy of {classes[i]:>5s}: N/A (no test examples)")
+
+print(
+    f"\nTest accuracy (overall): "
+    f"{100.0 * np.sum(class_correct) / np.sum(class_total):2.0f}% "
+    f"({int(np.sum(class_correct))}/{int(np.sum(class_total))})"
+)
 ```
\ No newline at end of file