mod_4_6-td3

Forked from Dellandrea Emmanuel / MOD_4_6-TD3
To adapt the code so that the ViT model runs on the CIFAR-10 dataset, first load the data:

# Load the CIFAR dataset

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

# number of subprocesses to use for data loading
num_workers = 4
# how many samples per batch to load
batch_size = 128
# percentage of training set to use as validation
valid_size = 0.2

# convert data to a normalized torch.FloatTensor
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# choose the training and test datasets
train_data = datasets.CIFAR10("data", train=True, download=True, transform=transform)
test_data = datasets.CIFAR10("data", train=False, download=True, transform=transform)

# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader_cifar = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers
)
valid_loader_cifar = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers
)
test_loader_cifar = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers
)

# specify the image classes
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
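
As a quick check that the loaders and normalization are set up correctly, you can display a few images from one training batch (a minimal sketch using matplotlib, which is not imported in the code above):

import matplotlib.pyplot as plt

# Grab one batch of training images
images, labels = next(iter(train_loader_cifar))

fig = plt.figure(figsize=(12, 3))
for idx in range(8):
    ax = fig.add_subplot(1, 8, idx + 1, xticks=[], yticks=[])
    # Undo the (0.5, 0.5, 0.5) normalization before displaying
    img = images[idx] * 0.5 + 0.5
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(classes[labels[idx].item()])
plt.show()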


A quick sanity check on a single batch (note: this uses device, model_cifar, and criterion, which are initialized in the next step):

for batch in train_loader_cifar:
    x, y = batch
    print(x.shape)  # (batch_size, 3, 32, 32)
    print(y.shape)  # (batch_size,)

    x, y = x.to(device), y.to(device)
    y_hat = model_cifar(x)
    # Calculate the batch loss (CrossEntropyLoss already averages over the batch)
    loss = criterion(y_hat, y)
    print(loss)
    max_value, predicted_label = torch.max(y_hat, 1)
    print(predicted_label)
    print(y)
    print(model_cifar)
    break

Then, to run the ViT model on the CIFAR dataset, we need to initialize the device and the model:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "Using device: ",
    device,
    f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
)

model_cifar = ViT(
    (3, 32, 32), n_patches=8, n_blocks=4, hidden_d=16, n_heads=2, out_d=10
).to(device)  # input: 3 x 32 x 32 images, 4 transformer blocks; 32 is divisible by n_patches=8

N_EPOCHS = 5
LR = 0.005
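
The n_patches=8 setting splits each 32 x 32 image into an 8 x 8 grid of 4 x 4 patches, i.e. 64 patches of 3 * 4 * 4 = 48 values each, which is why the input side must be divisible by 8. The actual patchify code lives inside the ViT class; the sketch below only illustrates the dimension arithmetic:

# Illustration only: cutting one CIFAR-sized image into an 8 x 8 grid of patches
x = torch.randn(3, 32, 32)                     # a single (C, H, W) image
p = 32 // 8                                    # patch side = 4 pixels
patches = x.unfold(1, p, p).unfold(2, p, p)    # (3, 8, 8, 4, 4)
patches = patches.permute(1, 2, 0, 3, 4).reshape(8 * 8, -1)
print(patches.shape)                           # torch.Size([64, 48])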

Then train and validate the model:

from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

optimizer = Adam(model_cifar.parameters(), lr=LR)
criterion = CrossEntropyLoss()
val_accuracies = []
train_loss_list_cifar = []  # list to store per-epoch losses, to visualize later
valid_loss_min_cifar = np.inf  # track change in validation loss


for epoch in range(N_EPOCHS):

    # Put the model in training mode
    model_cifar.train()
    train_loss = 0.0
    for batch in tqdm(train_loader_cifar):
        # Forward pass: compute predicted outputs by passing inputs to the model
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model_cifar(x)
        # Calculate the batch loss
        loss = criterion(y_hat, y)
        # Clear the gradients of all optimized variables before the backward pass
        optimizer.zero_grad()
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Perform a single optimization step (parameter update)
        optimizer.step()
        # Update training loss
        train_loss += loss.detach().cpu().item() / len(train_loader_cifar)
    train_loss_list_cifar.append(train_loss)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} train loss: {train_loss:.2f}")

    # At the end of each training epoch, evaluate the model on the validation set
    model_cifar.eval()
    with torch.no_grad():  # inference mode: no gradient computation
        correct, total = 0, 0
        valid_loss = 0.0
        for batch in tqdm(valid_loader_cifar):
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Forward pass: compute predicted outputs by passing inputs to the model
            y_hat = model_cifar(x)

            # Calculate the batch loss
            loss = criterion(y_hat, y)
            # Update average validation loss (divide by the validation loader length)
            valid_loss += loss.item() / len(valid_loader_cifar)

            # Compute accuracy
            max_value, predicted_label = torch.max(y_hat, 1)
            total += y.size(0)  # running count of samples processed
            correct += (predicted_label == y).sum().item()
        val_accuracies.append(correct / total)
        print(f"Validation loss: {valid_loss:.2f}")
        print(f"Validation accuracy: {correct / total * 100:.2f}%")


    # Save model if validation loss has decreased
    if valid_loss <= valid_loss_min_cifar:
        print(
            "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
                valid_loss_min_cifar, valid_loss
            )
        )
        torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt")
        valid_loss_min_cifar = valid_loss
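
Since train_loss_list_cifar and val_accuracies were collected for visualization, you can plot them once training finishes (a small matplotlib sketch, assuming the lists filled above):

import matplotlib.pyplot as plt

# Visualize the curves recorded during training
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(train_loss_list_cifar)
ax1.set_xlabel("epoch")
ax1.set_ylabel("train loss")
ax2.plot(val_accuracies)
ax2.set_xlabel("epoch")
ax2.set_ylabel("validation accuracy")
plt.show()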

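Training saved the weights with the lowest validation loss to model_cifar_transformers.pt. To evaluate those weights rather than the last epoch's, reload them first (a one-line sketch):

# Reload the weights that achieved the lowest validation loss
model_cifar.load_state_dict(torch.load("model_cifar_transformers.pt"))
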
Finally, evaluate the model on the test set:

# track test loss and per-class accuracy
class_correct = list(0.0 for i in range(10))
class_total = list(0.0 for i in range(10))


model_cifar.eval()
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0

    for batch in tqdm(test_loader_cifar):
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Forward pass
        y_hat = model_cifar(x)

        # Compute the loss
        loss = criterion(y_hat, y)
        test_loss += loss.item()

        # Compute accuracy
        _, predicted = torch.max(y_hat, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        # Per-class bookkeeping
        for i in range(len(y)):
            label = y[i].item()
            class_correct[label] += (predicted[i] == label).item()
            class_total[label] += 1

    print(f"Test loss: {test_loss / len(test_loader_cifar):.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

for i in range(10):
    if class_total[i] > 0:
        print(
            "Test Accuracy of %5s: %2d%% (%2d/%2d)"
            % (
                classes[i],
                100 * class_correct[i] / class_total[i],
                class_correct[i],
                class_total[i],
            )
        )
    else:
        print("Test Accuracy of %5s: N/A (no test examples)" % (classes[i]))

print(
    "\nTest Accuracy (Overall): %2d%% (%2d/%2d)"
    % (
        100.0 * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct),
        np.sum(class_total),
    )
)