From ae34a97d818d0ccf6b79c91a84a0e9ab2721919e Mon Sep 17 00:00:00 2001 From: Danjou Pierre <pierre.danjou@etu.ec-lyon.fr> Date: Wed, 8 Jan 2025 16:52:58 +0000 Subject: [PATCH] test --- README.md | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..144715d --- /dev/null +++ b/README.md @@ -0,0 +1,61 @@ +To adapt the code to apply the ViT model on CIFAR dataset : + +```python +# Load the CIFAR dataset + +import numpy as np +from torchvision import datasets, transforms +from torch.utils.data.sampler import SubsetRandomSampler + +# number of subprocesses to use for data loading +num_workers = 4 +# how many samples per batch to load +batch_size = 128 +# percentage of training set to use as validation +valid_size = 0.2 + +# convert data to a normalized torch.FloatTensor +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] +) + +# choose the training and test datasets +train_data = datasets.CIFAR10("data", train=True, download=True, transform=transform) +test_data = datasets.CIFAR10("data", train=False, download=True, transform=transform) + +# obtain training indices that will be used for validation +num_train = len(train_data) +indices = list(range(num_train)) +np.random.shuffle(indices) +split = int(np.floor(valid_size * num_train)) +train_idx, valid_idx = indices[split:], indices[:split] + +# define samplers for obtaining training and validation batches +train_sampler = SubsetRandomSampler(train_idx) +valid_sampler = SubsetRandomSampler(valid_idx) + +# prepare data loaders (combine dataset and sampler) +train_loader_cifar = torch.utils.data.DataLoader( + train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers +) +valid_loader_cifar = torch.utils.data.DataLoader( + train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers +) +test_loader_cifar = torch.utils.data.DataLoader( + test_data, batch_size=batch_size, num_workers=num_workers +) + +# specify the image classes +classes = [ + "airplane", + "automobile", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", +] +``` \ No newline at end of file -- GitLab