diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb
index 2ecfce959ae6b947b633a758433f9bea0bf6992e..de44c488965cb322cfe8f5c028ef905fcf7c66e0 100644
--- a/TD2 Deep Learning.ipynb	
+++ b/TD2 Deep Learning.ipynb	
@@ -31,14 +31,53 @@
     "Install and test PyTorch from  https://pytorch.org/get-started/locally."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "I am using a personal remote Jupyter server running on a PC with two GPUs."
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "330a42f5",
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Wed Nov 22 20:09:07 2023       \n",
+      "+---------------------------------------------------------------------------------------+\n",
+      "| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |\n",
+      "|-----------------------------------------+----------------------+----------------------+\n",
+      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
+      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
+      "|                                         |                      |               MIG M. |\n",
+      "|=========================================+======================+======================|\n",
+      "|   0  NVIDIA GeForce RTX 3090        Off | 00000000:01:00.0 Off |                  N/A |\n",
+      "|  0%   48C    P8              19W / 420W |    269MiB / 24576MiB |      0%      Default |\n",
+      "|                                         |                      |                  N/A |\n",
+      "+-----------------------------------------+----------------------+----------------------+\n",
+      "|   1  NVIDIA GeForce RTX 3070        Off | 00000000:07:00.0 Off |                  N/A |\n",
+      "|  0%   50C    P8              15W / 220W |     10MiB /  8192MiB |      0%      Default |\n",
+      "|                                         |                      |                  N/A |\n",
+      "+-----------------------------------------+----------------------+----------------------+\n",
+      "                                                                                         \n",
+      "+---------------------------------------------------------------------------------------+\n",
+      "| Processes:                                                                            |\n",
+      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
+      "|        ID   ID                                                             Usage      |\n",
+      "|=======================================================================================|\n",
+      "|    0   N/A  N/A      1670      G   /usr/lib/xorg/Xorg                          247MiB |\n",
+      "|    0   N/A  N/A      1841      G   /usr/bin/gnome-shell                         11MiB |\n",
+      "|    1   N/A  N/A      1670      G   /usr/lib/xorg/Xorg                            4MiB |\n",
+      "+---------------------------------------------------------------------------------------+\n"
+     ]
+    }
+   ],
    "source": [
-    "%pip install torch torchvision"
+    "!nvidia-smi"
    ]
   },
   {
@@ -52,10 +91,72 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "id": "b1950f0a",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 1.8855e-03,  2.9108e-01,  1.0010e+00, -1.7448e+00,  1.8745e+00,\n",
+      "          3.0039e-01, -1.0610e+00, -4.2353e-01,  1.5882e+00,  4.7778e-01],\n",
+      "        [ 6.9629e-01, -2.1862e-01, -1.1275e+00, -2.2043e-01, -9.0295e-01,\n",
+      "         -1.9357e+00, -9.8882e-01,  9.4581e-01, -1.1930e+00,  9.8106e-01],\n",
+      "        [-3.3844e-01, -1.0876e-01, -4.0938e-01,  2.8881e-01, -1.9949e-01,\n",
+      "         -5.8048e-01, -7.5410e-01, -1.2992e+00, -4.5162e-01,  9.0695e-01],\n",
+      "        [ 9.0743e-01,  1.8967e+00,  2.6043e+00,  6.2713e-01,  8.9104e-01,\n",
+      "          4.2479e-01, -1.1447e+00, -1.5549e+00, -1.1788e+00, -3.2302e-01],\n",
+      "        [-1.1279e+00,  7.4629e-02, -9.2078e-01, -7.8896e-01, -2.4876e-01,\n",
+      "          1.5658e-01, -2.8966e-01, -1.0835e+00,  1.1235e+00, -6.8270e-01],\n",
+      "        [ 3.4747e-01, -1.4344e+00,  7.0211e-01,  1.9160e+00, -1.5627e+00,\n",
+      "          8.7415e-02,  7.2565e-01, -2.4600e-02, -2.1433e-01, -4.1230e-01],\n",
+      "        [ 1.6323e-01, -6.4762e-01,  7.1466e-02, -4.9402e-01,  4.6785e-01,\n",
+      "          1.2793e+00,  1.7295e+00,  1.6134e-01,  1.1057e+00, -9.2903e-01],\n",
+      "        [ 2.7565e-01,  1.1653e+00, -1.8649e+00, -6.0089e-01,  1.4255e-01,\n",
+      "          5.1984e-01,  1.4124e+00,  4.3731e-01, -7.1495e-01,  4.4668e-01],\n",
+      "        [-2.0284e+00,  3.2644e-02, -1.0220e+00, -7.5502e-01,  1.4939e+00,\n",
+      "          2.1324e+00, -9.7155e-02,  4.4492e-01,  2.0190e+00, -1.4172e+00],\n",
+      "        [ 7.7286e-01, -1.8415e-01, -1.7536e-01, -5.6652e-01, -1.4285e+00,\n",
+      "          1.0795e+00, -3.8429e-01, -1.8018e+00, -9.7339e-02,  7.7694e-01],\n",
+      "        [-8.3219e-01, -3.1330e-01,  5.0993e-01,  4.6975e-01, -2.6981e-01,\n",
+      "         -1.7035e-01,  8.6431e-01,  5.9563e-01, -1.7859e-01,  1.8930e+00],\n",
+      "        [-4.5472e-01,  1.7444e+00,  4.8612e-01, -5.4073e-01,  1.4415e+00,\n",
+      "          9.6243e-01,  6.3097e-01, -6.6990e-01,  1.6233e+00, -1.1163e+00],\n",
+      "        [ 2.8865e-01, -8.5031e-01,  8.4932e-01, -1.6480e-01, -9.4282e-01,\n",
+      "          1.9159e+00, -4.7449e-01,  1.0314e-01,  4.7082e-01, -1.5315e+00],\n",
+      "        [ 6.2820e-01, -9.2092e-01, -6.6795e-01,  5.1397e-01,  3.8067e-01,\n",
+      "         -1.4796e-02, -2.6149e-01, -1.2254e+00, -5.8194e-01, -1.5822e+00]])\n",
+      "AlexNet(\n",
+      "  (features): Sequential(\n",
+      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
+      "    (1): ReLU(inplace=True)\n",
+      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
+      "    (4): ReLU(inplace=True)\n",
+      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (7): ReLU(inplace=True)\n",
+      "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (9): ReLU(inplace=True)\n",
+      "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (11): ReLU(inplace=True)\n",
+      "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  )\n",
+      "  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
+      "  (classifier): Sequential(\n",
+      "    (0): Dropout(p=0.5, inplace=False)\n",
+      "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
+      "    (2): ReLU(inplace=True)\n",
+      "    (3): Dropout(p=0.5, inplace=False)\n",
+      "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
+      "    (5): ReLU(inplace=True)\n",
+      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
+      "  )\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
@@ -95,12 +196,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "6e18f2fd",
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
     "import torch\n",
+    "import numpy as np\n",
+    "from torchvision import datasets, transforms\n",
+    "from torch.utils.data.sampler import SubsetRandomSampler\n",
+    "import torch.optim as optim\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "6e18f2fd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CUDA is available!  Training on GPU ...\n"
+     ]
+    }
+   ],
+   "source": [
     "\n",
     "# check if CUDA is available\n",
     "train_on_gpu = torch.cuda.is_available()\n",
@@ -121,14 +244,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "id": "462666a2",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Files already downloaded and verified\n",
+      "Files already downloaded and verified\n"
+     ]
+    }
+   ],
    "source": [
-    "import numpy as np\n",
-    "from torchvision import datasets, transforms\n",
-    "from torch.utils.data.sampler import SubsetRandomSampler\n",
     "\n",
     "# number of subprocesses to use for data loading\n",
     "num_workers = 0\n",
@@ -193,17 +322,28 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "id": "317bf070",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Net(\n",
+      "  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
+      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
+      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
+      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
+      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "\n",
     "# define the CNN architecture\n",
     "\n",
-    "\n",
     "class Net(nn.Module):\n",
     "    def __init__(self):\n",
     "        super(Net, self).__init__()\n",
@@ -217,6 +357,7 @@
     "    def forward(self, x):\n",
     "        x = self.pool(F.relu(self.conv1(x)))\n",
     "        x = self.pool(F.relu(self.conv2(x)))\n",
+    "        # print(f'x.shape = {x.shape}')\n",
     "        x = x.view(-1, 16 * 5 * 5)\n",
     "        x = F.relu(self.fc1(x))\n",
     "        x = F.relu(self.fc2(x))\n",
@@ -237,23 +378,88 @@
    "id": "a2dc4974",
    "metadata": {},
    "source": [
-    "Loss function and training using SGD (Stochastic Gradient Descent) optimizer"
+    "Loss function and training using SGD (Stochastic Gradient Descent) optimizer\n",
+    "> I added a running validation loss so that overfitting, if it occurs, can be seen visually. <br>"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 10,
    "id": "4b53f229",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0 \tTraining Loss: 42.195077 \tValidation Loss: 37.653781\n",
+      "Validation loss decreased (inf --> 37.653781).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 34.255484 \tValidation Loss: 31.326523\n",
+      "Validation loss decreased (37.653781 --> 31.326523).  Saving model ...\n",
+      "Epoch: 2 \tTraining Loss: 29.847037 \tValidation Loss: 29.189521\n",
+      "Validation loss decreased (31.326523 --> 29.189521).  Saving model ...\n",
+      "Epoch: 3 \tTraining Loss: 27.616926 \tValidation Loss: 27.410780\n",
+      "Validation loss decreased (29.189521 --> 27.410780).  Saving model ...\n",
+      "Epoch: 4 \tTraining Loss: 26.090010 \tValidation Loss: 26.955581\n",
+      "Validation loss decreased (27.410780 --> 26.955581).  Saving model ...\n",
+      "Epoch: 5 \tTraining Loss: 24.904551 \tValidation Loss: 25.293382\n",
+      "Validation loss decreased (26.955581 --> 25.293382).  Saving model ...\n",
+      "Epoch: 6 \tTraining Loss: 23.761944 \tValidation Loss: 24.479699\n",
+      "Validation loss decreased (25.293382 --> 24.479699).  Saving model ...\n",
+      "Epoch: 7 \tTraining Loss: 22.853249 \tValidation Loss: 24.050349\n",
+      "Validation loss decreased (24.479699 --> 24.050349).  Saving model ...\n",
+      "Epoch: 8 \tTraining Loss: 21.941920 \tValidation Loss: 23.283681\n",
+      "Validation loss decreased (24.050349 --> 23.283681).  Saving model ...\n",
+      "Epoch: 9 \tTraining Loss: 21.182262 \tValidation Loss: 23.741841\n",
+      "Epoch: 10 \tTraining Loss: 20.406208 \tValidation Loss: 22.520819\n",
+      "Validation loss decreased (23.283681 --> 22.520819).  Saving model ...\n",
+      "Epoch: 11 \tTraining Loss: 19.685864 \tValidation Loss: 22.102845\n",
+      "Validation loss decreased (22.520819 --> 22.102845).  Saving model ...\n",
+      "Epoch: 12 \tTraining Loss: 19.020183 \tValidation Loss: 21.780847\n",
+      "Validation loss decreased (22.102845 --> 21.780847).  Saving model ...\n",
+      "Epoch: 13 \tTraining Loss: 18.338785 \tValidation Loss: 22.500668\n",
+      "Epoch: 14 \tTraining Loss: 17.766254 \tValidation Loss: 22.892189\n",
+      "Epoch: 15 \tTraining Loss: 17.163492 \tValidation Loss: 21.602836\n",
+      "Validation loss decreased (21.780847 --> 21.602836).  Saving model ...\n",
+      "Epoch: 16 \tTraining Loss: 16.566336 \tValidation Loss: 21.696428\n",
+      "Epoch: 17 \tTraining Loss: 16.004122 \tValidation Loss: 22.403157\n",
+      "Epoch: 18 \tTraining Loss: 15.494520 \tValidation Loss: 22.072053\n",
+      "Epoch: 19 \tTraining Loss: 15.033227 \tValidation Loss: 21.885703\n",
+      "Epoch: 20 \tTraining Loss: 14.516863 \tValidation Loss: 22.182539\n",
+      "Epoch: 21 \tTraining Loss: 14.037906 \tValidation Loss: 22.538788\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "\u001b[1;32m/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2 Deep Learning.ipynb Cell 17\u001b[0m line \u001b[0;36m1\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=13'>14</a>\u001b[0m \u001b[39m# Train the model\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=14'>15</a>\u001b[0m model\u001b[39m.\u001b[39mtrain()\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=15'>16</a>\u001b[0m \u001b[39mfor\u001b[39;00m data, target \u001b[39min\u001b[39;00m train_loader:\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=16'>17</a>\u001b[0m     \u001b[39m# Move tensors to GPU if CUDA is available\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=17'>18</a>\u001b[0m     \u001b[39mif\u001b[39;00m train_on_gpu:\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#X23sZmlsZQ%3D%3D?line=18'>19</a>\u001b[0m         data, target \u001b[39m=\u001b[39m data\u001b[39m.\u001b[39mcuda(), target\u001b[39m.\u001b[39mcuda()\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    627\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    628\u001b[0m     \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    629\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()  \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m    631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m    632\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    633\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    634\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    672\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m    673\u001b[0m     index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index()  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 674\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index)  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    675\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m    676\u001b[0m         data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     49\u001b[0m         data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m     50\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m         data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m     53\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     49\u001b[0m         data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset\u001b[39m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m     50\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m         data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m     53\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torchvision/datasets/cifar.py:118\u001b[0m, in \u001b[0;36mCIFAR10.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m    115\u001b[0m img \u001b[39m=\u001b[39m Image\u001b[39m.\u001b[39mfromarray(img)\n\u001b[1;32m    117\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransform \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 118\u001b[0m     img \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransform(img)\n\u001b[1;32m    120\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtarget_transform \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    121\u001b[0m     target \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtarget_transform(target)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torchvision/transforms/transforms.py:95\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m     93\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, img):\n\u001b[1;32m     94\u001b[0m     \u001b[39mfor\u001b[39;00m t \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransforms:\n\u001b[0;32m---> 95\u001b[0m         img \u001b[39m=\u001b[39m t(img)\n\u001b[1;32m     96\u001b[0m     \u001b[39mreturn\u001b[39;00m img\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torchvision/transforms/transforms.py:277\u001b[0m, in \u001b[0;36mNormalize.forward\u001b[0;34m(self, tensor)\u001b[0m\n\u001b[1;32m    269\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, tensor: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[1;32m    270\u001b[0m \u001b[39m    \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m    271\u001b[0m \u001b[39m    Args:\u001b[39;00m\n\u001b[1;32m    272\u001b[0m \u001b[39m        tensor (Tensor): Tensor image to be normalized.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    275\u001b[0m \u001b[39m        Tensor: Normalized Tensor image.\u001b[39;00m\n\u001b[1;32m    276\u001b[0m \u001b[39m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 277\u001b[0m     \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mnormalize(tensor, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmean, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstd, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minplace)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torchvision/transforms/functional.py:363\u001b[0m, in \u001b[0;36mnormalize\u001b[0;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[1;32m    360\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(tensor, torch\u001b[39m.\u001b[39mTensor):\n\u001b[1;32m    361\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mimg should be Tensor Image. Got \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(tensor)\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 363\u001b[0m \u001b[39mreturn\u001b[39;00m F_t\u001b[39m.\u001b[39;49mnormalize(tensor, mean\u001b[39m=\u001b[39;49mmean, std\u001b[39m=\u001b[39;49mstd, inplace\u001b[39m=\u001b[39;49minplace)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torchvision/transforms/_functional_tensor.py:920\u001b[0m, in \u001b[0;36mnormalize\u001b[0;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[1;32m    917\u001b[0m     tensor \u001b[39m=\u001b[39m tensor\u001b[39m.\u001b[39mclone()\n\u001b[1;32m    919\u001b[0m dtype \u001b[39m=\u001b[39m tensor\u001b[39m.\u001b[39mdtype\n\u001b[0;32m--> 920\u001b[0m mean \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mas_tensor(mean, dtype\u001b[39m=\u001b[39;49mdtype, device\u001b[39m=\u001b[39;49mtensor\u001b[39m.\u001b[39;49mdevice)\n\u001b[1;32m    921\u001b[0m std \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mas_tensor(std, dtype\u001b[39m=\u001b[39mdtype, device\u001b[39m=\u001b[39mtensor\u001b[39m.\u001b[39mdevice)\n\u001b[1;32m    922\u001b[0m \u001b[39mif\u001b[39;00m (std \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m)\u001b[39m.\u001b[39many():\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
+    }
+   ],
    "source": [
-    "import torch.optim as optim\n",
     "\n",
     "criterion = nn.CrossEntropyLoss()  # specify loss function\n",
     "optimizer = optim.SGD(model.parameters(), lr=0.01)  # specify optimizer\n",
     "\n",
     "n_epochs = 30  # number of epochs to train the model\n",
     "train_loss_list = []  # list to store loss to visualize\n",
+    "running_validation_loss = []\n",
     "valid_loss_min = np.Inf  # track change in validation loss\n",
     "\n",
     "for epoch in range(n_epochs):\n",
@@ -297,6 +503,7 @@
     "    train_loss = train_loss / len(train_loader)\n",
     "    valid_loss = valid_loss / len(valid_loader)\n",
     "    train_loss_list.append(train_loss)\n",
+    "    running_validation_loss.append(valid_loss)\n",
     "\n",
     "    # Print training/validation statistics\n",
     "    print(\n",
@@ -321,22 +528,36 @@
    "id": "13e1df74",
    "metadata": {},
    "source": [
-    "Does overfit occur? If so, do an early stopping."
+    "Does overfit occur? If so, do an early stopping.\n",
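+    "> Yes: the validation loss stopped improving after epoch 15 while the training loss kept decreasing, so I stopped training early (manually) at around epoch 21."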
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "id": "d39df818",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "import matplotlib.pyplot as plt\n",
     "\n",
-    "plt.plot(range(n_epochs), train_loss_list)\n",
+    "plt.plot([i for i in range(len(train_loss_list))], train_loss_list)\n",
+    "plt.plot([i for i in range(len(running_validation_loss))], running_validation_loss)\n",
     "plt.xlabel(\"Epoch\")\n",
     "plt.ylabel(\"Loss\")\n",
     "plt.title(\"Performance of Model 1\")\n",
+    "plt.legend(['Train loss', 'Validation loss'])\n",
     "plt.show()"
    ]
   },
@@ -350,10 +571,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "id": "e93efdfc",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Test Loss: 21.142969\n",
+      "\n",
+      "Test Accuracy of airplane: 63% (635/1000)\n",
+      "Test Accuracy of automobile: 77% (779/1000)\n",
+      "Test Accuracy of  bird: 45% (451/1000)\n",
+      "Test Accuracy of   cat: 52% (528/1000)\n",
+      "Test Accuracy of  deer: 62% (624/1000)\n",
+      "Test Accuracy of   dog: 48% (482/1000)\n",
+      "Test Accuracy of  frog: 72% (726/1000)\n",
+      "Test Accuracy of horse: 70% (708/1000)\n",
+      "Test Accuracy of  ship: 76% (768/1000)\n",
+      "Test Accuracy of truck: 67% (672/1000)\n",
+      "\n",
+      "Test Accuracy (Overall): 63% (6373/10000)\n"
+     ]
+    }
+   ],
    "source": [
     "model.load_state_dict(torch.load(\"./model_cifar.pt\"))\n",
     "\n",
@@ -371,6 +613,7 @@
     "    # forward pass: compute predicted outputs by passing inputs to the model\n",
     "    output = model(data)\n",
     "    # calculate the batch loss\n",
+    "    print(f'data = {data}\\n target = {target}')\n",
     "    loss = criterion(output, target)\n",
     "    # update test loss\n",
     "    test_loss += loss.item() * data.size(0)\n",
@@ -434,6 +677,304 @@
     "Compare the results obtained with this new network to those obtained previously."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CustomNet(\n",
+      "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  (fc1): Linear(in_features=1024, out_features=520, bias=True)\n",
+      "  (fc2): Linear(in_features=520, out_features=64, bias=True)\n",
+      "  (fc3): Linear(in_features=64, out_features=10, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "class CustomNet(nn.Module):\n",
+    "    \"\"\"Small CNN: three conv/ReLU/max-pool stages followed by three linear layers.\n",
+    "\n",
+    "    Every convolution is 3x3 with stride 1 and padding 1 (3 -> 16 -> 32 -> 64 channels)\n",
+    "    and every stage ends with a 2x2 max-pool. `fc1` takes 64 * 4 * 4 features, i.e. a\n",
+    "    4x4 map after the three poolings, which corresponds to 3x32x32 inputs. The network\n",
+    "    returns 10 raw scores per sample; no softmax is applied here.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)\n",
+    "        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)\n",
+    "\n",
+    "        # One pooling module is shared by the three stages (it has no parameters)\n",
+    "        self.pool = nn.MaxPool2d(kernel_size=2)\n",
+    "\n",
+    "        self.fc1 = nn.Linear(64 * 4 * 4, 520)\n",
+    "        self.fc2 = nn.Linear(520, 64)\n",
+    "        self.fc3 = nn.Linear(64, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"Return the 10 class scores for a batch `x` of shape (N, 3, 32, 32).\"\"\"\n",
+    "        x = self.pool(F.relu(self.conv1(x)))\n",
+    "        x = self.pool(F.relu(self.conv2(x)))\n",
+    "        x = self.pool(F.relu(self.conv3(x)))\n",
+    "\n",
+    "        # Flatten everything except the batch dimension. A reshape to a fixed width\n",
+    "        # would silently fold samples together when the spatial size is unexpected;\n",
+    "        # this way `fc1` raises a shape error instead.\n",
+    "        x = x.flatten(start_dim=1)\n",
+    "\n",
+    "        x = F.relu(self.fc1(x))\n",
+    "        x = F.relu(self.fc2(x))\n",
+    "        x = self.fc3(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "# Instantiate the network and print its layer summary\n",
+    "custom_model = CustomNet()\n",
+    "print(custom_model)\n",
+    "# Move the model's parameters to the GPU when `train_on_gpu` is set\n",
+    "if train_on_gpu:\n",
+    "    custom_model.cuda()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Training of the new version of the neural network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0 \tTraining Loss: 1.134157 \tValidation Loss: 35.816960\n",
+      "Validation loss decreased (inf --> 35.816960).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 1.088380 \tValidation Loss: 37.159479\n",
+      "Epoch: 2 \tTraining Loss: 0.910855 \tValidation Loss: 37.244466\n",
+      "Epoch: 3 \tTraining Loss: 0.457045 \tValidation Loss: 39.297563\n",
+      "Epoch: 4 \tTraining Loss: 0.567809 \tValidation Loss: 39.350203\n",
+      "Epoch: 5 \tTraining Loss: 0.581826 \tValidation Loss: 39.824758\n",
+      "Epoch: 6 \tTraining Loss: 0.475355 \tValidation Loss: 41.428177\n",
+      "Epoch: 7 \tTraining Loss: 0.446757 \tValidation Loss: 39.261758\n",
+      "Epoch: 8 \tTraining Loss: 0.179121 \tValidation Loss: 42.353380\n",
+      "Epoch: 9 \tTraining Loss: 0.107020 \tValidation Loss: 44.952167\n",
+      "Epoch: 10 \tTraining Loss: 0.038190 \tValidation Loss: 42.309538\n",
+      "Epoch: 11 \tTraining Loss: 0.006820 \tValidation Loss: 43.020387\n",
+      "Epoch: 12 \tTraining Loss: 0.004769 \tValidation Loss: 43.565180\n",
+      "Epoch: 13 \tTraining Loss: 0.003950 \tValidation Loss: 44.026827\n",
+      "Epoch: 14 \tTraining Loss: 0.003434 \tValidation Loss: 44.447328\n",
+      "Epoch: 15 \tTraining Loss: 0.003062 \tValidation Loss: 44.806831\n",
+      "Epoch: 16 \tTraining Loss: 0.002774 \tValidation Loss: 45.128416\n",
+      "Epoch: 17 \tTraining Loss: 0.002539 \tValidation Loss: 45.398260\n",
+      "Epoch: 18 \tTraining Loss: 0.002356 \tValidation Loss: 45.667429\n",
+      "Epoch: 19 \tTraining Loss: 0.002191 \tValidation Loss: 45.895782\n",
+      "Epoch: 20 \tTraining Loss: 0.002060 \tValidation Loss: 46.140433\n",
+      "Epoch: 21 \tTraining Loss: 0.001937 \tValidation Loss: 46.355325\n",
+      "Epoch: 22 \tTraining Loss: 0.001834 \tValidation Loss: 46.542832\n",
+      "Epoch: 23 \tTraining Loss: 0.001739 \tValidation Loss: 46.742928\n",
+      "Epoch: 24 \tTraining Loss: 0.001656 \tValidation Loss: 46.915519\n",
+      "Epoch: 25 \tTraining Loss: 0.001580 \tValidation Loss: 47.089529\n",
+      "Epoch: 26 \tTraining Loss: 0.001512 \tValidation Loss: 47.235664\n",
+      "Epoch: 27 \tTraining Loss: 0.001450 \tValidation Loss: 47.399466\n",
+      "Epoch: 28 \tTraining Loss: 0.001393 \tValidation Loss: 47.562623\n",
+      "Epoch: 29 \tTraining Loss: 0.001340 \tValidation Loss: 47.701985\n"
+     ]
+    }
+   ],
+   "source": [
+    "criterion = nn.CrossEntropyLoss()  # specify loss function\n",
+    "optimizer = optim.SGD(custom_model.parameters(), lr=0.01)  # specify optimizer\n",
+    "\n",
+    "n_epochs = 30  # number of epochs to train the custom_model\n",
+    "train_loss_list = []  # list to store loss to visualize\n",
+    "# Best validation loss seen so far (np.inf: the np.Inf alias was removed in NumPy 2.0)\n",
+    "valid_loss_min = np.inf\n",
+    "\n",
+    "for epoch in range(n_epochs):\n",
+    "    # Keep track of training and validation loss\n",
+    "    train_loss = 0.0\n",
+    "    valid_loss = 0.0\n",
+    "\n",
+    "    # Train the custom_model\n",
+    "    custom_model.train()\n",
+    "    for data, target in train_loader:\n",
+    "        # Move tensors to GPU if CUDA is available\n",
+    "        if train_on_gpu:\n",
+    "            data, target = data.cuda(), target.cuda()\n",
+    "        # Clear the gradients of all optimized variables\n",
+    "        optimizer.zero_grad()\n",
+    "        # Forward pass: compute predicted outputs by passing inputs to the custom_model\n",
+    "        output = custom_model(data)\n",
+    "        # Calculate the batch loss\n",
+    "        loss = criterion(output, target)\n",
+    "        # Backward pass: compute gradient of the loss with respect to custom_model parameters\n",
+    "        loss.backward()\n",
+    "        # Perform a single optimization step (parameter update)\n",
+    "        optimizer.step()\n",
+    "        # Update training loss (batch mean weighted by the batch size)\n",
+    "        train_loss += loss.item() * data.size(0)\n",
+    "\n",
+    "    # Validate the custom_model; no parameter update happens here, so gradient\n",
+    "    # tracking is disabled to avoid building autograd graphs\n",
+    "    custom_model.eval()\n",
+    "    with torch.no_grad():\n",
+    "        for data, target in valid_loader:\n",
+    "            # Move tensors to GPU if CUDA is available\n",
+    "            if train_on_gpu:\n",
+    "                data, target = data.cuda(), target.cuda()\n",
+    "            # Forward pass: compute predicted outputs by passing inputs to the custom_model\n",
+    "            output = custom_model(data)\n",
+    "            # Calculate the batch loss\n",
+    "            loss = criterion(output, target)\n",
+    "            # Update validation loss (batch mean weighted by the batch size)\n",
+    "            valid_loss += loss.item() * data.size(0)\n",
+    "\n",
+    "    # Calculate average losses\n",
+    "    # NOTE(review): the sums above are weighted by batch size but divided by the number\n",
+    "    # of batches, so these values are the per-sample mean scaled by the batch size.\n",
+    "    # Left unchanged so the printed numbers keep their current scale - TODO confirm intent.\n",
+    "    train_loss = train_loss / len(train_loader)\n",
+    "    valid_loss = valid_loss / len(valid_loader)\n",
+    "    train_loss_list.append(train_loss)\n",
+    "\n",
+    "    # Print training/validation statistics\n",
+    "    print(\n",
+    "        \"Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}\".format(\n",
+    "            epoch, train_loss, valid_loss\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "    # Save model if validation loss has decreased\n",
+    "    if valid_loss <= valid_loss_min:\n",
+    "        print(\n",
+    "            \"Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...\".format(\n",
+    "                valid_loss_min, valid_loss\n",
+    "            )\n",
+    "        )\n",
+    "        torch.save(custom_model.state_dict(), \"custom_model2_cifar.pt\")\n",
+    "        valid_loss_min = valid_loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# Take the x-axis from the recorded losses rather than from n_epochs, so the plot\n",
+    "# still works when the two differ (e.g. an interrupted training run)\n",
+    "plt.plot(range(len(train_loss_list)), train_loss_list)\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Loss\")\n",
+    "plt.title(\"Performance of CustomModel\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Test Loss: 36.168822\n",
+      "\n",
+      "Test Accuracy of airplane: 72% (728/1000)\n",
+      "Test Accuracy of automobile: 83% (836/1000)\n",
+      "Test Accuracy of  bird: 63% (637/1000)\n",
+      "Test Accuracy of   cat: 61% (617/1000)\n",
+      "Test Accuracy of  deer: 63% (635/1000)\n",
+      "Test Accuracy of   dog: 56% (562/1000)\n",
+      "Test Accuracy of  frog: 66% (664/1000)\n",
+      "Test Accuracy of horse: 76% (767/1000)\n",
+      "Test Accuracy of  ship: 87% (876/1000)\n",
+      "Test Accuracy of truck: 74% (743/1000)\n",
+      "\n",
+      "Test Accuracy (Overall): 70% (7065/10000)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Restore the checkpoint written by the training cell\n",
+    "custom_model.load_state_dict(torch.load(\"./custom_model2_cifar.pt\"))\n",
+    "\n",
+    "# track test loss and per-class hit counts (10 classes)\n",
+    "test_loss = 0.0\n",
+    "class_correct = [0.0] * 10\n",
+    "class_total = [0.0] * 10\n",
+    "\n",
+    "custom_model.eval()\n",
+    "# iterate over test data; inference only, so gradient tracking is disabled\n",
+    "with torch.no_grad():\n",
+    "    for data, target in test_loader:\n",
+    "        # move tensors to GPU if CUDA is available\n",
+    "        if train_on_gpu:\n",
+    "            data, target = data.cuda(), target.cuda()\n",
+    "        # forward pass: compute predicted outputs by passing inputs to the custom_model\n",
+    "        output = custom_model(data)\n",
+    "        # calculate the batch loss\n",
+    "        loss = criterion(output, target)\n",
+    "        # update test loss (batch mean weighted by the batch size)\n",
+    "        test_loss += loss.item() * data.size(0)\n",
+    "        # the index of the highest score is the predicted class\n",
+    "        _, pred = torch.max(output, 1)\n",
+    "        # compare predictions to true label; .cpu() is a no-op for CPU tensors.\n",
+    "        # The array is deliberately not squeezed so a one-sample batch stays indexable.\n",
+    "        correct = pred.eq(target.view_as(pred)).cpu().numpy()\n",
+    "        # calculate test accuracy for each object class, using the real batch length\n",
+    "        # because the last batch can be shorter than batch_size\n",
+    "        for i in range(len(target)):\n",
+    "            label = target[i].item()\n",
+    "            class_correct[label] += correct[i].item()\n",
+    "            class_total[label] += 1\n",
+    "\n",
+    "# average test loss\n",
+    "# NOTE(review): a batch-size-weighted sum divided by the number of batches, i.e. the\n",
+    "# per-sample mean scaled by the batch size; left unchanged - TODO confirm intent.\n",
+    "test_loss = test_loss / len(test_loader)\n",
+    "print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
+    "\n",
+    "for i in range(10):\n",
+    "    if class_total[i] > 0:\n",
+    "        print(\n",
+    "            \"Test Accuracy of %5s: %2d%% (%2d/%2d)\"\n",
+    "            % (\n",
+    "                classes[i],\n",
+    "                100 * class_correct[i] / class_total[i],\n",
+    "                np.sum(class_correct[i]),\n",
+    "                np.sum(class_total[i]),\n",
+    "            )\n",
+    "        )\n",
+    "    else:\n",
+    "        print(\"Test Accuracy of %5s: N/A (no training examples)\" % (classes[i]))\n",
+    "\n",
+    "print(\n",
+    "    \"\\nTest Accuracy (Overall): %2d%% (%2d/%2d)\"\n",
+    "    % (\n",
+    "        100.0 * np.sum(class_correct) / np.sum(class_total),\n",
+    "        np.sum(class_correct),\n",
+    "        np.sum(class_total),\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
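+    "> We observe an increased overall test accuracy with the modified model (70% vs. 63% for the first model), although its validation loss rises during training while the training loss goes to zero, which indicates overfitting."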
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "bc381cf4",
@@ -451,23 +992,73 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 49,
    "id": "ef623c26",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model:  fp32  \t Size (KB): 2365.954\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "2365954"
+      ]
+     },
+     "execution_count": 49,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "import os\n",
     "\n",
     "\n",
     "def print_size_of_model(model, label=\"\"):\n",
+    "    \"\"\"Save `model`'s state_dict to a temporary file, print its size in KB, return it in bytes.\n",
+    "\n",
+    "    `label` only appears in the printed line. The temporary file \"temp.p\" is created\n",
+    "    in the current directory and removed before returning.\n",
+    "    \"\"\"\n",
+    "    # Serialize the model that was passed in (not a notebook global), so that each\n",
+    "    # call measures its own argument\n",
     "    torch.save(model.state_dict(), \"temp.p\")\n",
     "    size = os.path.getsize(\"temp.p\")\n",
     "    print(\"model: \", label, \" \\t\", \"Size (KB):\", size / 1e3)\n",
     "    os.remove(\"temp.p\")\n",
     "    return size\n",
     "\n",
     "\n",
-    "print_size_of_model(model, \"fp32\")"
+    "print_size_of_model(custom_model, \"fp32\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model:  fp32  \t Size (KB): 2365.954\n",
+      "model:  int8  \t Size (KB): 2365.954\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "2365954"
+      ]
+     },
+     "execution_count": 60,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Rebuild the first network and restore its trained weights (tensors mapped to the CPU).\n",
+    "# load_state_dict() updates the module in place and returns a report of missing and\n",
+    "# unexpected keys, so its result must not be assigned back to `first_model`.\n",
+    "first_model = Net()\n",
+    "first_model.load_state_dict(torch.load(\"./model_cifar.pt\", map_location=\"cpu\"))\n",
+    "print_size_of_model(first_model, \"fp32\")\n",
+    "# Dynamically quantize the first model itself (int8 weights) and measure it\n",
+    "quantized_first_model = torch.quantization.quantize_dynamic(first_model, dtype=torch.qint8)\n",
+    "print_size_of_model(quantized_first_model, \"int8\")"
    ]
   },
   {
@@ -480,15 +1071,33 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 62,
    "id": "c4c65d4b",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model:  int8  \t Size (KB): 2365.954\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "2365954"
+      ]
+     },
+     "execution_count": 62,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "import torch.quantization\n",
     "\n",
     "\n",
-    "quantized_model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)\n",
+    "import copy\n",
+    "\n",
+    "# The dynamically quantized linear kernels only exist for the CPU backend, so quantize\n",
+    "# a CPU copy and leave `custom_model` itself on its current device. The resulting\n",
+    "# model must be evaluated with CPU tensors.\n",
+    "custom_model_cpu = copy.deepcopy(custom_model).cpu()\n",
+    "quantized_model = torch.quantization.quantize_dynamic(custom_model_cpu, dtype=torch.qint8)\n",
     "print_size_of_model(quantized_model, \"int8\")"
    ]
   },
@@ -500,6 +1109,177 @@
     "For each class, compare the classification test accuracy of the initial model and the quantized model. Also give the overall test accuracy for both models."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Compute correct classes for the quantized model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NotImplementedError",
+     "evalue": "Could not run 'quantized::linear_dynamic' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear_dynamic' is only available for these backends: [CPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].\n\nCPU: registered at ../aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp:662 [kernel]\nBackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]\nPython: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]\nFuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]\nFunctionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]\nNamed: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]\nConjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]\nNegative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]\nZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]\nADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]\nAutogradOther: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]\nAutogradCPU: registered 
at ../aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]\nAutogradCUDA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]\nAutogradXLA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]\nAutogradMPS: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]\nAutogradXPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]\nAutogradHPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]\nAutogradLazy: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]\nAutogradMeta: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]\nTracer: registered at ../torch/csrc/autograd/TraceTypeManual.cpp:296 [backend fallback]\nAutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:382 [backend fallback]\nAutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:249 [backend fallback]\nFuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:710 [backend fallback]\nFuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]\nBatched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]\nVmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]\nFuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]\nPythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]\nFuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]\nPreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]\nPythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:157 
[backend fallback]\n",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNotImplementedError\u001b[0m                       Traceback (most recent call last)",
+      "\u001b[1;32m/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2 Deep Learning.ipynb Cell 38\u001b[0m line \u001b[0;36m1\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m     data, target \u001b[39m=\u001b[39m data\u001b[39m.\u001b[39mcuda(), target\u001b[39m.\u001b[39mcuda()\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m \u001b[39m# forward pass: compute predicted outputs by passing inputs to the quantized_model\u001b[39;00m\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m output \u001b[39m=\u001b[39m quantized_model(data)\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=13'>14</a>\u001b[0m \u001b[39m# calculate the batch loss\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=14'>15</a>\u001b[0m loss \u001b[39m=\u001b[39m criterion(output, target)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
+      "\u001b[1;32m/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2 Deep Learning.ipynb Cell 38\u001b[0m line \u001b[0;36m2\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=18'>19</a>\u001b[0m x \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mview(\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, \u001b[39m64\u001b[39m \u001b[39m*\u001b[39m \u001b[39m4\u001b[39m \u001b[39m*\u001b[39m \u001b[39m4\u001b[39m)\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=20'>21</a>\u001b[0m \u001b[39m# print(f'before first fully connected: x.shape = {x.shape}')\u001b[39;00m\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=21'>22</a>\u001b[0m x \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mrelu(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfc1(x))\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=22'>23</a>\u001b[0m \u001b[39m# print(f'after first fully connected, x.shape = {x.shape}')\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/vl/Documents/4A/liming_chen_deep/be2/mod_4_6-td2/TD2%20Deep%20Learning.ipynb#Y113sZmlsZQ%3D%3D?line=23'>24</a>\u001b[0m x \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mrelu(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfc2(x))\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py:54\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     51\u001b[0m         Y \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mops\u001b[39m.\u001b[39mquantized\u001b[39m.\u001b[39mlinear_dynamic(\n\u001b[1;32m     52\u001b[0m             x, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_packed_params\u001b[39m.\u001b[39m_packed_params)\n\u001b[1;32m     53\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 54\u001b[0m         Y \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mops\u001b[39m.\u001b[39;49mquantized\u001b[39m.\u001b[39;49mlinear_dynamic(\n\u001b[1;32m     55\u001b[0m             x, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_packed_params\u001b[39m.\u001b[39;49m_packed_params, reduce_range\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m     56\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_packed_params\u001b[39m.\u001b[39mdtype \u001b[39m==\u001b[39m torch\u001b[39m.\u001b[39mfloat16:\n\u001b[1;32m     57\u001b[0m     Y \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mops\u001b[39m.\u001b[39mquantized\u001b[39m.\u001b[39mlinear_dynamic_fp16(\n\u001b[1;32m     58\u001b[0m         x, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_packed_params\u001b[39m.\u001b[39m_packed_params)\n",
+      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_ops.py:692\u001b[0m, in \u001b[0;36mOpOverloadPacket.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    687\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m    688\u001b[0m     \u001b[39m# overloading __call__ to ensure torch.ops.foo.bar()\u001b[39;00m\n\u001b[1;32m    689\u001b[0m     \u001b[39m# is still callable from JIT\u001b[39;00m\n\u001b[1;32m    690\u001b[0m     \u001b[39m# We save the function ptr as the `op` attribute on\u001b[39;00m\n\u001b[1;32m    691\u001b[0m     \u001b[39m# OpOverloadPacket to access it here.\u001b[39;00m\n\u001b[0;32m--> 692\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_op(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs \u001b[39mor\u001b[39;49;00m {})\n",
+      "\u001b[0;31mNotImplementedError\u001b[0m: Could not run 'quantized::linear_dynamic' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear_dynamic' is only available for these backends: [CPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].\n\nCPU: registered at ../aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp:662 [kernel]\nBackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]\nPython: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]\nFuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]\nFunctionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]\nNamed: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]\nConjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]\nNegative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]\nZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]\nADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]\nAutogradOther: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend 
fallback]\nAutogradCPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]\nAutogradCUDA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]\nAutogradXLA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]\nAutogradMPS: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]\nAutogradXPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]\nAutogradHPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]\nAutogradLazy: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]\nAutogradMeta: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]\nTracer: registered at ../torch/csrc/autograd/TraceTypeManual.cpp:296 [backend fallback]\nAutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:382 [backend fallback]\nAutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:249 [backend fallback]\nFuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:710 [backend fallback]\nFuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]\nBatched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]\nVmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]\nFuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]\nPythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]\nFuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]\nPreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]\nPythonDispatcher: registered at 
../aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# per-class evaluation of the dynamically quantized model (runs on the CPU, see loop)\n",
+    "test_loss = 0.0\n",
+    "class_correct_quantized = [0.0] * 10\n",
+    "class_total_quantized = [0.0] * 10\n",
+    "\n",
+    "quantized_model.cpu().eval()  # the non-quantized layers must sit on the CPU as well\n",
+    "# iterate over test data\n",
+    "for data, target in test_loader:\n",
+    "    # quantized::linear_dynamic only has a CPU kernel, so the batch stays on the CPU\n",
+    "    # even when train_on_gpu is True (calling .cuda() here raised NotImplementedError)\n",
+    "    data, target = data.cpu(), target.cpu()\n",
+    "    # forward pass: compute predicted outputs by passing inputs to the quantized_model\n",
+    "    output = quantized_model(data)\n",
+    "    # calculate the batch loss\n",
+    "    loss = criterion(output, target)\n",
+    "    # update test loss\n",
+    "    test_loss += loss.item() * data.size(0)\n",
+    "    # convert output probabilities to predicted class\n",
+    "    _, pred = torch.max(output, 1)\n",
+    "    # compare predictions to true label\n",
+    "    correct_tensor = pred.eq(target.data.view_as(pred))\n",
+    "    correct = (\n",
+    "        np.squeeze(correct_tensor.numpy())\n",
+    "        if not train_on_gpu\n",
+    "        else np.squeeze(correct_tensor.cpu().numpy())\n",
+    "    )\n",
+    "    # calculate test accuracy for each object class\n",
+    "    for i in range(target.size(0)):  # the last batch may hold fewer than batch_size samples\n",
+    "        label = target.data[i]\n",
+    "        class_correct_quantized[label] += correct[i].item()\n",
+    "        class_total_quantized[label] += 1\n",
+    "\n",
+    "# average test loss\n",
+    "test_loss = test_loss / len(test_loader)\n",
+    "print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
+    "\n",
+    "for i in range(10):\n",
+    "    if class_total_quantized[i] > 0:\n",
+    "        print(\n",
+    "            \"Test Accuracy of %5s: %2d%% (%2d/%2d)\"\n",
+    "            % (\n",
+    "                classes[i],\n",
+    "                100 * class_correct_quantized[i] / class_total_quantized[i],\n",
+    "                np.sum(class_correct_quantized[i]),\n",
+    "                np.sum(class_total_quantized[i]),\n",
+    "            )\n",
+    "        )\n",
+    "    else:\n",
+    "        print(\"Test Accuracy of %5s: N/A (no training examples)\" % (classes[i]))\n",
+    "\n",
+    "print(\n",
+    "    \"\\nTest Accuracy (Overall): %2d%% (%2d/%2d)\"\n",
+    "    % (\n",
+    "        100.0 * np.sum(class_correct_quantized) / np.sum(class_total_quantized),\n",
+    "        np.sum(class_correct_quantized),\n",
+    "        np.sum(class_total_quantized),\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "class_correct = [728.0, 836.0, 637.0, 617.0, 635.0, 562.0, 664.0, 767.0, 876.0, 743.0]\n",
+      "class_total = [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f'class_correct = {class_correct}')\n",
+    "print(f'class_total = {class_total}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_accuracy_classes = [c_correct/c_total for (c_correct, c_total) in zip(class_correct, class_total)]\n",
+    "test_accuracy_classes_quantized = [c_correct/c_total for (c_correct, c_total) in zip(class_correct_quantized, class_total_quantized)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[0.728, 0.836, 0.637, 0.617, 0.635, 0.562, 0.664, 0.767, 0.876, 0.743]"
+      ]
+     },
+     "execution_count": 45,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_accuracy_classes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<BarContainer object of 10 artists>"
+      ]
+     },
+     "execution_count": 46,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.bar([i for i in range(1,11)], test_accuracy_classes, color='blue', edgecolor='black')"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "a0a34b90",
@@ -521,10 +1301,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 65,
    "id": "b4d13080",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/hacklexander/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
+      "  warnings.warn(\n",
+      "/home/hacklexander/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
+      "  warnings.warn(msg)\n",
+      "Downloading: \"https://download.pytorch.org/models/resnet50-0676ba61.pth\" to /home/hacklexander/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth\n",
+      "100.0%"
+     ]
+    }
+   ],
    "source": [
     "import json\n",
     "from PIL import Image\n",
@@ -586,6 +1379,36 @@
     "    \n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model:  fp32  \t Size (KB): 2365.954\n",
+      "model:  int8  \t Size (KB): 2365.954\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "2365954"
+      ]
+     },
+     "execution_count": 66,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "print_size_of_model(model, \"fp32\")\n",
+    "quantized_resnet50 = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)\n",
+    "print_size_of_model(quantized_resnet50, \"int8\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "5d57da4b",
@@ -604,10 +1427,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 68,
    "id": "be2d31f5",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "import os\n",
     "\n",
@@ -642,7 +1476,7 @@
     "    ),\n",
     "}\n",
     "\n",
-    "data_dir = \"hymenoptera_data\"\n",
+    "data_dir = \"hymenoptera_data/hymenoptera_data\"\n",
     "# Create train and validation datasets and loaders\n",
     "image_datasets = {\n",
     "    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])\n",
@@ -682,8 +1516,7 @@
     "# Make a grid from batch\n",
     "out = torchvision.utils.make_grid(inputs)\n",
     "\n",
-    "imshow(out, title=[class_names[x] for x in classes])\n",
-    "\n"
+    "imshow(out, title=[class_names[x] for x in classes])"
    ]
   },
   {
@@ -696,10 +1529,102 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 70,
    "id": "572d824c",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/hacklexander/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
+      "  warnings.warn(\n",
+      "/home/hacklexander/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
+      "  warnings.warn(msg)\n",
+      "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /home/hacklexander/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
+      "0.1%"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100.0%\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1/10\n",
+      "----------\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/hacklexander/.local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:136: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
+      "  warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "train Loss: 0.6541 Acc: 0.6066\n",
+      "val Loss: 0.3065 Acc: 0.8497\n",
+      "\n",
+      "Epoch 2/10\n",
+      "----------\n",
+      "train Loss: 0.4837 Acc: 0.7910\n",
+      "val Loss: 0.1637 Acc: 0.9477\n",
+      "\n",
+      "Epoch 3/10\n",
+      "----------\n",
+      "train Loss: 0.4998 Acc: 0.7541\n",
+      "val Loss: 0.1664 Acc: 0.9477\n",
+      "\n",
+      "Epoch 4/10\n",
+      "----------\n",
+      "train Loss: 0.3661 Acc: 0.8156\n",
+      "val Loss: 0.4234 Acc: 0.8497\n",
+      "\n",
+      "Epoch 5/10\n",
+      "----------\n",
+      "train Loss: 0.3994 Acc: 0.8197\n",
+      "val Loss: 0.1507 Acc: 0.9542\n",
+      "\n",
+      "Epoch 6/10\n",
+      "----------\n",
+      "train Loss: 0.7117 Acc: 0.7254\n",
+      "val Loss: 0.2423 Acc: 0.9216\n",
+      "\n",
+      "Epoch 7/10\n",
+      "----------\n",
+      "train Loss: 0.2918 Acc: 0.8730\n",
+      "val Loss: 0.1659 Acc: 0.9542\n",
+      "\n",
+      "Epoch 8/10\n",
+      "----------\n",
+      "train Loss: 0.4188 Acc: 0.8156\n",
+      "val Loss: 0.1598 Acc: 0.9477\n",
+      "\n",
+      "Epoch 9/10\n",
+      "----------\n",
+      "train Loss: 0.3808 Acc: 0.8484\n",
+      "val Loss: 0.1795 Acc: 0.9412\n",
+      "\n",
+      "Epoch 10/10\n",
+      "----------\n",
+      "train Loss: 0.4709 Acc: 0.7869\n",
+      "val Loss: 0.1637 Acc: 0.9477\n",
+      "\n",
+      "Training complete in 0m 5s\n",
+      "Best val Acc: 0.954248\n"
+     ]
+    }
+   ],
    "source": [
     "import copy\n",
     "import os\n",
@@ -739,7 +1664,7 @@
     "    ),\n",
     "}\n",
     "\n",
-    "data_dir = \"hymenoptera_data\"\n",
+    "data_dir = \"hymenoptera_data/hymenoptera_data/\"\n",
     "# Create train and validation datasets and loaders\n",
     "image_datasets = {\n",
     "    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])\n",
@@ -878,7 +1803,7 @@
     "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)\n",
     "model, epoch_time = train_model(\n",
     "    model, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=10\n",
-    ")\n"
+    ")"
    ]
   },
   {
@@ -897,6 +1822,20 @@
     "Apply ther quantization (post and quantization aware) and evaluate impact on model size and accuracy."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Modification of eval_model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "markdown",
    "id": "04a263f0",
@@ -926,7 +1865,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.5 ('base')",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -940,7 +1879,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.10.12"
   },
   "vscode": {
    "interpreter": {