To adapt the code and apply the ViT model to the CIFAR-10 dataset, first load the data:
# Load the CIFAR-10 dataset
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
# number of subprocesses to use for data loading
num_workers = 4
# how many samples per batch to load
batch_size = 128
# percentage of training set to use as validation
valid_size = 0.2
# convert data to a normalized torch.FloatTensor
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
# choose the training and test datasets
train_data = datasets.CIFAR10("data", train=True, download=True, transform=transform)
test_data = datasets.CIFAR10("data", train=False, download=True, transform=transform)
# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
# prepare data loaders (combine dataset and sampler)
train_loader_cifar = torch.utils.data.DataLoader(
train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers
)
valid_loader_cifar = torch.utils.data.DataLoader(
train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers
)
test_loader_cifar = torch.utils.data.DataLoader(
test_data, batch_size=batch_size, num_workers=num_workers
)
# specify the image classes
classes = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
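Optionally, you can visualize a few training images to check that the normalization and labels line up. A minimal sketch using matplotlib (assuming it is installed); images are un-normalized with the same mean/std of 0.5 used above:
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader_cifar))
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for ax, img, lab in zip(axes, images[:5], labels[:5]):
    img = img * 0.5 + 0.5                    # undo Normalize((0.5, ...), (0.5, ...))
    ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
    ax.set_title(classes[lab.item()])
    ax.axis("off")
plt.show()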
# Sanity check: inspect a single batch (run this after model_cifar,
# criterion and device have been defined below)
for batch in train_loader_cifar:
    x, y = batch
    print(x.shape)  # torch.Size([128, 3, 32, 32])
    print(y.shape)  # torch.Size([128])
    x, y = x.to(device), y.to(device)
    y_hat = model_cifar(x)
    # Calculate the batch loss (CrossEntropyLoss already averages over the
    # batch with its default reduction="mean", so no division by 128 is needed)
    loss = criterion(y_hat, y)
    print(loss.item())
    max_value, predicted_label = torch.max(y_hat, 1)
    print(predicted_label)
    print(y)
    print(model_cifar)
    break
Then, to run the ViT model on the CIFAR-10 dataset, initialize the device, the model, and the training hyperparameters:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
"Using device: ",
device,
f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
)
model_cifar = ViT(
    (3, 32, 32), n_patches=8, n_blocks=4, hidden_d=16, n_heads=2, out_d=10
).to(device)  # input shape 3x32x32; n_patches=8 works because 32 is divisible by 8
N_EPOCHS = 5
LR = 0.005
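As a side note, n_patches must divide the image side length. A minimal sketch of the token shapes these arguments imply, assuming the ViT implementation splits each image into an n_patches x n_patches grid and prepends a class token, as in the standard ViT design:
chw, n_patches, hidden_d = (3, 32, 32), 8, 16
patch_size = chw[1] // n_patches        # 32 / 8 = 4, so each patch is 4x4 pixels
patch_dim = chw[0] * patch_size ** 2    # 3 * 4 * 4 = 48 values per flattened patch
n_tokens = n_patches ** 2 + 1           # 64 patches + 1 class token = 65 tokens
print(patch_size, patch_dim, n_tokens)  # 4 48 65
# Each token is projected to hidden_d = 16, so the transformer blocks
# operate on a (batch, 65, 16) tensor.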
Then train and validate the model:
from tqdm import tqdm
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

optimizer = Adam(model_cifar.parameters(), lr=LR)
criterion = CrossEntropyLoss()
val_accuracies = []
train_loss_list_cifar = []  # list to store the loss for visualization
valid_loss_min_cifar = np.inf  # track change in validation loss
for epoch in range(N_EPOCHS):
    # Put the model in training mode
    model_cifar.train()
    train_loss = 0.0
    for batch in tqdm(train_loader_cifar):
        # Forward pass: compute predicted outputs by passing inputs to the model
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model_cifar(x)
        # Calculate the batch loss
        loss = criterion(y_hat, y)
        # Clear the gradients of all optimized variables before the backward pass
        optimizer.zero_grad()
        # Backward pass: compute the gradient of the loss with respect to model parameters
        loss.backward()
        # Perform a single optimization step (parameter update)
        optimizer.step()
        # Update training loss
        train_loss += loss.detach().cpu().item() / len(train_loader_cifar)
    train_loss_list_cifar.append(train_loss)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} train loss: {train_loss:.2f}")
    # At the end of each training epoch, evaluate the model on the validation set
    model_cifar.eval()
    with torch.no_grad():  # inference mode: no gradients needed
        correct, total = 0, 0
        valid_loss = 0.0
        for batch in tqdm(valid_loader_cifar):
            x, y = batch
            x, y = x.to(device), y.to(device)
            # Forward pass: compute predicted outputs by passing inputs to the model
            y_hat = model_cifar(x)
            # Calculate the batch loss
            loss = criterion(y_hat, y)
            # Update average validation loss (divide by the number of
            # validation batches, not the test loader length)
            valid_loss += loss.item() / len(valid_loader_cifar)
            # Compute accuracy
            max_value, predicted_label = torch.max(y_hat, 1)
            total += y.size(0)  # number of samples processed so far
            correct += (predicted_label == y).sum().item()
    val_accuracies.append(correct / total)
    print(f"Validation loss: {valid_loss:.2f}")
    print(f"Validation accuracy: {correct / total * 100:.2f}%")
    # Save the model if the validation loss has decreased
    if valid_loss <= valid_loss_min_cifar:
        print(
            "Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(
                valid_loss_min_cifar, valid_loss
            )
        )
        torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt")
        valid_loss_min_cifar = valid_loss
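The training losses and validation accuracies collected above can then be plotted; a minimal sketch using matplotlib:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(train_loss_list_cifar)
ax1.set_xlabel("epoch")
ax1.set_ylabel("training loss")
ax2.plot(val_accuracies)
ax2.set_xlabel("epoch")
ax2.set_ylabel("validation accuracy")
plt.tight_layout()
plt.show()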
Finally, evaluate the trained model on the test set:
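Since the training loop saved the weights with the lowest validation loss to model_cifar_transformers.pt, you may want to reload that checkpoint first, so the evaluation uses the best model rather than the last-epoch weights:
# Optional: restore the best checkpoint saved during training
model_cifar.load_state_dict(
    torch.load("model_cifar_transformers.pt", map_location=device)
)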
# Track test loss and per-class accuracy
test_loss = 0.0
class_correct = [0.0] * 10
class_total = [0.0] * 10
model_cifar.eval()
with torch.no_grad():
    correct, total = 0, 0
    for batch in tqdm(test_loader_cifar):
        x, y = batch
        x, y = x.to(device), y.to(device)
        # Forward pass
        y_hat = model_cifar(x)
        # Compute the loss
        loss = criterion(y_hat, y)
        test_loss += loss.item()
        # Compute accuracy
        _, predicted = torch.max(y_hat, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
        # Accumulate per-class counts
        for i in range(len(y)):
            label = y[i].item()
            class_correct[label] += (predicted[i] == label).item()
            class_total[label] += 1
print(f"Test loss: {test_loss / len(test_loader_cifar):.2f}")
print(f"Test accuracy: {correct / total * 100:.2f}%")
for i in range(10):
    if class_total[i] > 0:
        print(
            "Test Accuracy of %5s: %2d%% (%2d/%2d)"
            % (
                classes[i],
                100 * class_correct[i] / class_total[i],
                class_correct[i],
                class_total[i],
            )
        )
    else:
        print("Test Accuracy of %5s: N/A (no test examples)" % classes[i])
print(
    "\nTest Accuracy (Overall): %2d%% (%2d/%2d)"
    % (
        100.0 * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct),
        np.sum(class_total),
    )
)
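As a quick usage example, you can also run the model on a single test image and map the prediction back to a class name (a minimal sketch; the index 0 is an arbitrary choice):
# Predict the class of one test image
x0, y0 = test_data[0]
with torch.no_grad():
    logits = model_cifar(x0.unsqueeze(0).to(device))  # add a batch dimension
pred = logits.argmax(dim=1).item()
print(f"predicted: {classes[pred]}, true: {classes[y0]}")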