From ae34a97d818d0ccf6b79c91a84a0e9ab2721919e Mon Sep 17 00:00:00 2001
From: Danjou Pierre <pierre.danjou@etu.ec-lyon.fr>
Date: Wed, 8 Jan 2025 16:52:58 +0000
Subject: [PATCH] test

---
 README.md | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)
 create mode 100644 README.md

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..144715d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,61 @@
+To adapt the code to apply the ViT model on CIFAR dataset : 
+
+```python
+# Load the CIFAR dataset
+
+import numpy as np
+from torchvision import datasets, transforms
+from torch.utils.data.sampler import SubsetRandomSampler
+
+# number of subprocesses to use for data loading
+num_workers = 4
+# how many samples per batch to load
+batch_size = 128
+# percentage of training set to use as validation
+valid_size = 0.2
+
+# convert data to a normalized torch.FloatTensor
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+)
+
+# choose the training and test datasets
+train_data = datasets.CIFAR10("data", train=True, download=True, transform=transform)
+test_data = datasets.CIFAR10("data", train=False, download=True, transform=transform)
+
+# obtain training indices that will be used for validation
+num_train = len(train_data)
+indices = list(range(num_train))
+np.random.shuffle(indices)
+split = int(np.floor(valid_size * num_train))
+train_idx, valid_idx = indices[split:], indices[:split]
+
+# define samplers for obtaining training and validation batches
+train_sampler = SubsetRandomSampler(train_idx)
+valid_sampler = SubsetRandomSampler(valid_idx)
+
+# prepare data loaders (combine dataset and sampler)
+train_loader_cifar = torch.utils.data.DataLoader(
+    train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers
+)
+valid_loader_cifar = torch.utils.data.DataLoader(
+    train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers
+)
+test_loader_cifar = torch.utils.data.DataLoader(
+    test_data, batch_size=batch_size, num_workers=num_workers
+)
+
+# specify the image classes
+classes = [
+    "airplane",
+    "automobile",
+    "bird",
+    "cat",
+    "deer",
+    "dog",
+    "frog",
+    "horse",
+    "ship",
+    "truck",
+]
+```
\ No newline at end of file
-- 
GitLab