Skip to content
Snippets Groups Projects
Commit 4bec3d7a authored by Danjou Pierre's avatar Danjou Pierre
Browse files

Update README.md

parent ae34a97d
No related branches found
No related tags found
No related merge requests found
...@@ -58,4 +58,178 @@ classes = [ ...@@ -58,4 +58,178 @@ classes = [
"ship", "ship",
"truck", "truck",
] ]
# Sanity check: push one training batch through the model and print the tensor
# shapes, the loss, the predicted labels and the model itself.
for batch in train_loader_cifar:
    x, y = batch
    print(x.shape)
    print(y.shape)

    # Move inputs and targets to the selected device.
    x = x.to(device)
    y = y.to(device)

    # Forward pass, then the loss for this single batch.
    y_hat = model_cifar(x)
    loss = criterion(y_hat, y)
    # NOTE(review): 128 is presumably the batch size -- confirm against the loader.
    print(loss / 128)

    # The index of the largest output along dim 1 is the predicted class.
    max_value, predicted_label = y_hat.max(dim=1)
    print(predicted_label)
    print(y)
    print(model_cifar)

    # Only the first batch is inspected.
    break
```
Then, to apply the ViT model to the CIFAR dataset, we first need to initialize the device, the model and the training hyperparameters:
```
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Report the chosen device, followed by the GPU name when running on CUDA.
device_name = f"({torch.cuda.get_device_name(device)})" if use_cuda else ""
print("Using device: ", device, device_name)

# ViT for 3 x 32 x 32 inputs with 4 blocks; the image side (32) is divisible
# by n_patches (8).
model_cifar = ViT(
    (3, 32, 32),
    n_patches=8,
    n_blocks=4,
    hidden_d=16,
    n_heads=2,
    out_d=10,
)
model_cifar = model_cifar.to(device)

# Training hyperparameters.
N_EPOCHS = 5
LR = 0.005
```
Then train and validate the model:
```
from tqdm import tqdm

# Train the ViT on CIFAR for N_EPOCHS epochs. After every epoch the model is
# evaluated on the validation set, and its weights are saved whenever the
# validation loss does not exceed the best value seen so far.
optimizer = Adam(model_cifar.parameters(), lr=LR)
criterion = CrossEntropyLoss()

val_accuracies = []  # validation accuracy after each epoch
train_loss_list_cifar = []  # average training loss per epoch, kept for plotting
# `np.inf` rather than `np.Inf`: the capitalised alias was removed in NumPy 2.0.
valid_loss_min_cifar = np.inf  # lowest validation loss seen so far

# The number of batches per loader is loop-invariant, so compute it once.
n_train_batches = len(train_loader_cifar)
n_valid_batches = len(valid_loader_cifar)

for epoch in range(N_EPOCHS):
    # ---- Training phase ----
    model_cifar.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader_cifar):
        x, y = x.to(device), y.to(device)

        # Forward pass and batch loss.
        y_hat = model_cifar(x)
        loss = criterion(y_hat, y)

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss into the epoch average.
        train_loss += loss.item() / n_train_batches

    train_loss_list_cifar.append(train_loss)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} train loss: {train_loss:.2f}")

    # ---- Validation phase (run at the end of every training epoch) ----
    model_cifar.eval()
    correct, total = 0, 0
    valid_loss = 0.0
    with torch.no_grad():  # inference only, no gradients needed
        for x, y in tqdm(valid_loader_cifar):
            x, y = x.to(device), y.to(device)

            y_hat = model_cifar(x)
            loss = criterion(y_hat, y)

            # Average over the *validation* loader, i.e. the loader actually
            # being iterated (the original divided by the test loader's length).
            valid_loss += loss.item() / n_valid_batches

            # Predicted class is the index of the largest output along dim 1.
            _, predicted_label = torch.max(y_hat, 1)
            total += y.size(0)  # number of samples processed so far
            correct += (predicted_label == y).sum().item()

    val_accuracy = correct / total
    val_accuracies.append(val_accuracy)
    print(f"Validation loss: {valid_loss :.2f}")
    print(f"Validation accuracy: {val_accuracy * 100:.2f}%")

    # Save the model if the validation loss has not increased.
    if valid_loss <= valid_loss_min_cifar:
        print(
            "Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(
                valid_loss_min_cifar, valid_loss
            )
        )
        torch.save(model_cifar.state_dict(), "model_cifar_transformers.pt")
        valid_loss_min_cifar = valid_loss
```
Finally, evaluate the trained model on the test set:
```
# Evaluate the model on the test set: average loss, overall accuracy and
# per-class accuracy.
n_classes = len(classes)
class_correct = [0] * n_classes  # correctly classified samples per class
class_total = [0] * n_classes  # samples seen per class

model_cifar.eval()
correct, total = 0, 0
test_loss = 0.0  # running sum of per-batch losses
with torch.no_grad():
    for x, y in tqdm(test_loader_cifar):
        x, y = x.to(device), y.to(device)

        # Forward pass and batch loss.
        y_hat = model_cifar(x)
        loss = criterion(y_hat, y)
        test_loss += loss.item()

        # Predicted class is the index of the largest output along dim 1.
        _, predicted = torch.max(y_hat, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        # Per-class tallies; .tolist() yields plain ints that can safely be
        # used as list indices.
        for label, guess in zip(y.tolist(), predicted.tolist()):
            class_total[label] += 1
            class_correct[label] += int(guess == label)

print(f"Test loss: {test_loss / len(test_loader_cifar):.2f}")
print(f"Test accuracy: {correct / total * 100:.2f}%")

for i in range(n_classes):
    if class_total[i] > 0:
        print(
            "Test Accuracy of %5s: %2d%% (%2d/%2d)"
            % (
                classes[i],
                100 * class_correct[i] / class_total[i],
                class_correct[i],
                class_total[i],
            )
        )
    else:
        # These counts come from the test loader, so the message says "test".
        print("Test Accuracy of %5s: N/A (no test examples)" % (classes[i]))

print(
    "\nTest Accuracy (Overall): %2d%% (%2d/%2d)"
    % (
        100.0 * sum(class_correct) / sum(class_total),
        sum(class_correct),
        sum(class_total),
    )
)
``` ```
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment