diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb
index 2ecfce959ae6b947b633a758433f9bea0bf6992e..780cc6f59bcac17f63cd9a03502f250b2ca6e700 100644
--- a/TD2 Deep Learning.ipynb	
+++ b/TD2 Deep Learning.ipynb	
@@ -33,12 +33,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "330a42f5",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: torch in c:\\users\\marin\\anaconda3\\lib\\site-packages (2.0.1)\n",
+      "Requirement already satisfied: torchvision in c:\\users\\marin\\anaconda3\\lib\\site-packages (0.15.2a0)\n",
+      "Requirement already satisfied: filelock in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torch) (3.9.0)\n",
+      "Requirement already satisfied: typing-extensions in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torch) (4.7.1)\n",
+      "Requirement already satisfied: sympy in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torch) (1.11.1)\n",
+      "Requirement already satisfied: networkx in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torch) (3.1)\n",
+      "Requirement already satisfied: jinja2 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torch) (3.1.2)\n",
+      "Requirement already satisfied: numpy in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torchvision) (1.24.3)\n",
+      "Requirement already satisfied: requests in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torchvision) (2.31.0)\n",
+      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from torchvision) (9.4.0)\n",
+      "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.1)\n",
+      "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.4)\n",
+      "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.16)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (2023.7.22)\n",
+      "Requirement already satisfied: mpmath>=0.19 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
    "source": [
-    "%pip install torch torchvision"
+    "%pip install torch torchvision\n"
    ]
   },
   {
@@ -52,10 +76,72 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "id": "b1950f0a",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 0.1899,  0.8863,  0.1413, -0.1998,  0.1497,  0.6815, -0.8749, -0.3087,\n",
+      "         -0.0238,  0.4644],\n",
+      "        [-1.0627, -0.3343,  0.9708, -1.1139,  0.3306,  0.0214, -0.8496, -0.0632,\n",
+      "         -1.1719, -1.7125],\n",
+      "        [ 0.1654,  0.8615, -0.4635,  1.1866, -0.4576,  0.3541, -0.0137,  0.8904,\n",
+      "          2.0490,  0.0841],\n",
+      "        [-0.6709,  0.1347,  0.6463,  0.8951, -0.7332,  0.0309,  0.2212,  0.6717,\n",
+      "          0.8933,  0.1584],\n",
+      "        [-0.3432, -0.4148, -0.8095, -0.1408,  0.5006, -1.1312,  0.3924,  0.4016,\n",
+      "         -0.5985, -0.1117],\n",
+      "        [-1.0762, -0.1045,  0.0249,  0.8469,  1.0076, -1.4642,  0.4522,  0.0082,\n",
+      "          0.7456,  0.7806],\n",
+      "        [-0.0753,  0.8675,  1.1765,  0.1967, -1.1167,  1.3006,  0.8635,  0.0400,\n",
+      "          1.0068,  0.8430],\n",
+      "        [ 0.1805, -1.4230,  0.8074, -0.3967,  1.5681, -1.2731, -1.2154, -1.3516,\n",
+      "         -1.3917, -0.2232],\n",
+      "        [-1.9622, -0.5655, -0.7118,  0.6445, -0.7508,  0.3790, -1.9274,  2.8144,\n",
+      "         -0.1963,  0.7060],\n",
+      "        [ 0.1464, -1.2219, -0.5618,  0.0519,  0.5780,  0.0497, -0.1709, -0.7162,\n",
+      "         -0.0512, -0.2961],\n",
+      "        [-1.1464, -1.7522,  0.4518, -0.7085, -0.3393, -0.9789,  0.8045,  0.4721,\n",
+      "         -0.6035, -0.6996],\n",
+      "        [ 1.2104,  0.4869, -0.4659,  1.3424,  0.4500, -1.6684,  0.1359, -0.2354,\n",
+      "         -0.6425,  0.4473],\n",
+      "        [-0.3536, -0.5641, -1.4005, -1.4136, -0.2599, -0.6156,  0.9142, -0.6475,\n",
+      "         -0.1155,  1.0220],\n",
+      "        [-0.9314,  0.9504,  0.6591,  1.6823, -0.7177, -0.9853, -1.7366, -1.0150,\n",
+      "         -2.3821,  0.3625]])\n",
+      "AlexNet(\n",
+      "  (features): Sequential(\n",
+      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
+      "    (1): ReLU(inplace=True)\n",
+      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
+      "    (4): ReLU(inplace=True)\n",
+      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (7): ReLU(inplace=True)\n",
+      "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (9): ReLU(inplace=True)\n",
+      "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (11): ReLU(inplace=True)\n",
+      "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  )\n",
+      "  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
+      "  (classifier): Sequential(\n",
+      "    (0): Dropout(p=0.5, inplace=False)\n",
+      "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
+      "    (2): ReLU(inplace=True)\n",
+      "    (3): Dropout(p=0.5, inplace=False)\n",
+      "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
+      "    (5): ReLU(inplace=True)\n",
+      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
+      "  )\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
@@ -95,10 +181,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "id": "6e18f2fd",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CUDA is not available.  Training on CPU ...\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
@@ -121,10 +215,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 30,
    "id": "462666a2",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Files already downloaded and verified\n",
+      "Files already downloaded and verified\n"
+     ]
+    }
+   ],
    "source": [
     "import numpy as np\n",
     "from torchvision import datasets, transforms\n",
@@ -193,10 +296,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "id": "317bf070",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Net(\n",
+      "  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
+      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
+      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
+      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
+      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
@@ -242,10 +360,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 29,
    "id": "4b53f229",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0 \tTraining Loss: 12.224831 \tValidation Loss: 19.767630\n",
+      "Validation loss decreased (inf --> 19.767630).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 11.779677 \tValidation Loss: 19.502304\n",
+      "Validation loss decreased (19.767630 --> 19.502304).  Saving model ...\n",
+      "Epoch: 2 \tTraining Loss: 11.402885 \tValidation Loss: 19.299951\n",
+      "Validation loss decreased (19.502304 --> 19.299951).  Saving model ...\n",
+      "Epoch: 3 \tTraining Loss: 10.922978 \tValidation Loss: 20.018273\n",
+      "Epoch: 4 \tTraining Loss: 10.535478 \tValidation Loss: 20.801044\n",
+      "Epoch: 5 \tTraining Loss: 10.175075 \tValidation Loss: 20.377939\n",
+      "Early stopping after 5 epochss.\n"
+     ]
+    }
+   ],
    "source": [
     "import torch.optim as optim\n",
     "\n",
@@ -254,8 +389,11 @@
     "\n",
     "n_epochs = 30  # number of epochs to train the model\n",
     "train_loss_list = []  # list to store loss to visualize\n",
+    "valid_loss_list = [] # We track validation loss to check for overfitting\n",
     "valid_loss_min = np.Inf  # track change in validation loss\n",
     "\n",
+    "patience = 3 # We stop the training if loss doesn't improve for 3 consecutive epochs\n",
+    "# Start the counter at zero so the `else` branch of the early-stopping check never reads\n",
+    "# an undefined value (fresh kernel, NaN first loss) or a stale one from a previous run\n",
+    "patience_counter = 0\n",
+    "\n",
     "for epoch in range(n_epochs):\n",
     "    # Keep track of training and validation loss\n",
     "    train_loss = 0.0\n",
@@ -297,7 +435,9 @@
     "    train_loss = train_loss / len(train_loader)\n",
     "    valid_loss = valid_loss / len(valid_loader)\n",
     "    train_loss_list.append(train_loss)\n",
+    "    valid_loss_list.append(valid_loss)\n",
     "\n",
+    "    \n",
     "    # Print training/validation statistics\n",
     "    print(\n",
     "        \"Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}\".format(\n",
@@ -313,7 +453,14 @@
     "            )\n",
     "        )\n",
     "        torch.save(model.state_dict(), \"model_cifar.pt\")\n",
-    "        valid_loss_min = valid_loss"
+    "        valid_loss_min = valid_loss\n",
+    "        patience_counter = 0\n",
+    "    else:\n",
+    "        patience_counter += 1\n",
+    "        \n",
+    "    # `epoch` is 0-based, so epoch + 1 epochs have been completed at this point\n",
+    "    if patience_counter >= patience:\n",
+    "        print(f\"Early stopping after {epoch + 1} epochs.\")\n",
+    "        break"
    ]
   },
   {
@@ -326,16 +473,29 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
    "id": "d39df818",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "import matplotlib.pyplot as plt\n",
     "\n",
-    "plt.plot(range(n_epochs), train_loss_list)\n",
+    "plt.plot(range(len(train_loss_list)), train_loss_list, label=\"train_loss\")\n",
+    "plt.plot(range(len(valid_loss_list)), valid_loss_list, label=\"validation_loss\")\n",
     "plt.xlabel(\"Epoch\")\n",
     "plt.ylabel(\"Loss\")\n",
+    "plt.legend()\n",
     "plt.title(\"Performance of Model 1\")\n",
     "plt.show()"
    ]
@@ -350,10 +510,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 20,
    "id": "e93efdfc",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Test Loss: 21.403010\n",
+      "\n",
+      "Test Accuracy of airplane: 68% (681/1000)\n",
+      "Test Accuracy of automobile: 74% (744/1000)\n",
+      "Test Accuracy of  bird: 50% (508/1000)\n",
+      "Test Accuracy of   cat: 37% (371/1000)\n",
+      "Test Accuracy of  deer: 56% (566/1000)\n",
+      "Test Accuracy of   dog: 48% (480/1000)\n",
+      "Test Accuracy of  frog: 78% (780/1000)\n",
+      "Test Accuracy of horse: 72% (724/1000)\n",
+      "Test Accuracy of  ship: 74% (748/1000)\n",
+      "Test Accuracy of truck: 72% (724/1000)\n",
+      "\n",
+      "Test Accuracy (Overall): 63% (6326/10000)\n"
+     ]
+    }
+   ],
    "source": [
     "model.load_state_dict(torch.load(\"./model_cifar.pt\"))\n",
     "\n",
@@ -419,9 +600,11 @@
   },
   {
    "cell_type": "markdown",
-   "id": "944991a2",
+   "id": "491cc760",
    "metadata": {},
    "source": [
+    "### Creating a new network\n",
+    "\n",
     "Build a new network with the following structure.\n",
     "\n",
     "- It has 3 convolutional layers of kernel size 3 and padding of 1.\n",
@@ -434,6 +617,380 @@
     "Compare the results obtained with this new network to those obtained previously."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "43fff7d9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Net(\n",
+      "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  (fc1): Linear(in_features=1024, out_features=512, bias=True)\n",
+      "  (fc2): Linear(in_features=512, out_features=64, bias=True)\n",
+      "  (fc3): Linear(in_features=64, out_features=10, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "# We create the new model\n",
+    "\n",
+    "class Net(nn.Module):\n",
+    "    \"\"\"CNN with three conv layers and three fully connected layers, 10 output scores.\n",
+    "\n",
+    "    Each 3x3 convolution (padding 1, so the spatial size is preserved) is followed by\n",
+    "    ReLU and 2x2 max pooling. The flatten step assumes a 64 x 4 x 4 feature map, which\n",
+    "    corresponds to 32x32 inputs (32 -> 16 -> 8 -> 4). Dropout (p=0.5) follows the first\n",
+    "    two fully connected layers and is only active in training mode.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super(Net, self).__init__()\n",
+    "        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)\n",
+    "        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)\n",
+    "        \n",
+    "        self.pool = nn.MaxPool2d(2, 2)\n",
+    "        \n",
+    "        self.fc1 = nn.Linear(64 * 4 * 4, 512)\n",
+    "        self.fc2 = nn.Linear(512, 64)\n",
+    "        self.fc3 = nn.Linear(64, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"Return the raw class scores (logits) for a batch of images `x`.\"\"\"\n",
+    "        x = self.pool(F.relu(self.conv1(x)))\n",
+    "        x = self.pool(F.relu(self.conv2(x)))\n",
+    "        x = self.pool(F.relu(self.conv3(x)))\n",
+    "        x = x.view(-1, 64 * 4 * 4)\n",
+    "        # F.dropout defaults to training=True, so it must be tied to self.training;\n",
+    "        # otherwise units keep being dropped after model.eval() (validation / test).\n",
+    "        x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)\n",
+    "        x = F.dropout(F.relu(self.fc2(x)), p=0.5, training=self.training)\n",
+    "        x = self.fc3(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "# create a complete CNN\n",
+    "model2 = Net()\n",
+    "print(model2)\n",
+    "# move tensors to GPU if CUDA is available\n",
+    "if train_on_gpu:\n",
+    "    model2.cuda()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "109ca0c2",
+   "metadata": {},
+   "source": [
+    "## Training the new model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "40638ce8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0 \tTraining Loss: 45.992308 \tValidation Loss: 45.734079\n",
+      "Validation loss decreased (inf --> 45.734079).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 42.861547 \tValidation Loss: 39.263893\n",
+      "Validation loss decreased (45.734079 --> 39.263893).  Saving model ...\n",
+      "Epoch: 2 \tTraining Loss: 36.771661 \tValidation Loss: 34.267659\n",
+      "Validation loss decreased (39.263893 --> 34.267659).  Saving model ...\n",
+      "Epoch: 3 \tTraining Loss: 33.214925 \tValidation Loss: 32.355083\n",
+      "Validation loss decreased (34.267659 --> 32.355083).  Saving model ...\n",
+      "Epoch: 4 \tTraining Loss: 31.095436 \tValidation Loss: 30.169813\n",
+      "Validation loss decreased (32.355083 --> 30.169813).  Saving model ...\n",
+      "Epoch: 5 \tTraining Loss: 29.427925 \tValidation Loss: 28.723804\n",
+      "Validation loss decreased (30.169813 --> 28.723804).  Saving model ...\n",
+      "Epoch: 6 \tTraining Loss: 27.979727 \tValidation Loss: 27.368319\n",
+      "Validation loss decreased (28.723804 --> 27.368319).  Saving model ...\n",
+      "Epoch: 7 \tTraining Loss: 26.560045 \tValidation Loss: 26.531596\n",
+      "Validation loss decreased (27.368319 --> 26.531596).  Saving model ...\n",
+      "Epoch: 8 \tTraining Loss: 25.233062 \tValidation Loss: 25.151196\n",
+      "Validation loss decreased (26.531596 --> 25.151196).  Saving model ...\n",
+      "Epoch: 9 \tTraining Loss: 23.933243 \tValidation Loss: 24.118065\n",
+      "Validation loss decreased (25.151196 --> 24.118065).  Saving model ...\n",
+      "Epoch: 10 \tTraining Loss: 22.824302 \tValidation Loss: 23.136555\n",
+      "Validation loss decreased (24.118065 --> 23.136555).  Saving model ...\n",
+      "Epoch: 11 \tTraining Loss: 21.638535 \tValidation Loss: 22.276559\n",
+      "Validation loss decreased (23.136555 --> 22.276559).  Saving model ...\n",
+      "Epoch: 12 \tTraining Loss: 20.643672 \tValidation Loss: 21.490677\n",
+      "Validation loss decreased (22.276559 --> 21.490677).  Saving model ...\n",
+      "Epoch: 13 \tTraining Loss: 19.723929 \tValidation Loss: 20.878862\n",
+      "Validation loss decreased (21.490677 --> 20.878862).  Saving model ...\n",
+      "Epoch: 14 \tTraining Loss: 18.881656 \tValidation Loss: 20.151909\n",
+      "Validation loss decreased (20.878862 --> 20.151909).  Saving model ...\n",
+      "Epoch: 15 \tTraining Loss: 18.077398 \tValidation Loss: 22.761932\n",
+      "Epoch: 16 \tTraining Loss: 17.244630 \tValidation Loss: 20.172645\n",
+      "Epoch: 17 \tTraining Loss: 16.708238 \tValidation Loss: 19.282629\n",
+      "Validation loss decreased (20.151909 --> 19.282629).  Saving model ...\n",
+      "Epoch: 18 \tTraining Loss: 16.049521 \tValidation Loss: 19.141060\n",
+      "Validation loss decreased (19.282629 --> 19.141060).  Saving model ...\n",
+      "Epoch: 19 \tTraining Loss: 15.306451 \tValidation Loss: 18.852022\n",
+      "Validation loss decreased (19.141060 --> 18.852022).  Saving model ...\n",
+      "Epoch: 20 \tTraining Loss: 14.730358 \tValidation Loss: 18.173483\n",
+      "Validation loss decreased (18.852022 --> 18.173483).  Saving model ...\n",
+      "Epoch: 21 \tTraining Loss: 14.233885 \tValidation Loss: 18.950112\n",
+      "Epoch: 22 \tTraining Loss: 13.598710 \tValidation Loss: 19.136868\n",
+      "Epoch: 23 \tTraining Loss: 13.122754 \tValidation Loss: 18.415266\n",
+      "Early stopping after 23 epochss.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch.optim as optim\n",
+    "\n",
+    "criterion = nn.CrossEntropyLoss()  # specify loss function\n",
+    "optimizer = optim.SGD(model2.parameters(), lr=0.01)  # specify optimizer\n",
+    "\n",
+    "n_epochs = 30  # number of epochs to train the model\n",
+    "train_loss_list = []  # list to store loss to visualize\n",
+    "valid_loss_list = []  # We track validation loss to check for overfitting\n",
+    "valid_loss_min = np.inf  # track change in validation loss (np.Inf was removed in NumPy 2.0)\n",
+    "\n",
+    "patience = 3  # We stop the training if loss doesn't improve for 3 consecutive epochs\n",
+    "# Consecutive epochs without a new best validation loss; initialised here so the\n",
+    "# `else` branch below never reads an undefined or stale value\n",
+    "patience_counter = 0\n",
+    "\n",
+    "for epoch in range(n_epochs):\n",
+    "    # Keep track of training and validation loss\n",
+    "    train_loss = 0.0\n",
+    "    valid_loss = 0.0\n",
+    "\n",
+    "    # Train the model\n",
+    "    model2.train()\n",
+    "    for data, target in train_loader:\n",
+    "        # Move tensors to GPU if CUDA is available\n",
+    "        if train_on_gpu:\n",
+    "            data, target = data.cuda(), target.cuda()\n",
+    "        # Clear the gradients of all optimized variables\n",
+    "        optimizer.zero_grad()\n",
+    "        # Forward pass: compute predicted outputs by passing inputs to the model\n",
+    "        output = model2(data)\n",
+    "        # Calculate the batch loss\n",
+    "        loss = criterion(output, target)\n",
+    "        # Backward pass: compute gradient of the loss with respect to model parameters\n",
+    "        loss.backward()\n",
+    "        # Perform a single optimization step (parameter update)\n",
+    "        optimizer.step()\n",
+    "        # Update training loss\n",
+    "        train_loss += loss.item() * data.size(0)\n",
+    "\n",
+    "    # Validate the model: eval mode, and no gradient tracking since nothing is optimized here\n",
+    "    model2.eval()\n",
+    "    with torch.no_grad():\n",
+    "        for data, target in valid_loader:\n",
+    "            # Move tensors to GPU if CUDA is available\n",
+    "            if train_on_gpu:\n",
+    "                data, target = data.cuda(), target.cuda()\n",
+    "            # Forward pass: compute predicted outputs by passing inputs to the model\n",
+    "            output = model2(data)\n",
+    "            # Calculate the batch loss\n",
+    "            loss = criterion(output, target)\n",
+    "            # Update average validation loss\n",
+    "            valid_loss += loss.item() * data.size(0)\n",
+    "\n",
+    "    # Calculate average losses\n",
+    "    # NOTE: the sums above are weighted by batch size but divided by the number of batches,\n",
+    "    # so these are average per-batch summed losses (about batch-size times the per-sample mean)\n",
+    "    train_loss = train_loss / len(train_loader)\n",
+    "    valid_loss = valid_loss / len(valid_loader)\n",
+    "    train_loss_list.append(train_loss)\n",
+    "    valid_loss_list.append(valid_loss)\n",
+    "\n",
+    "    # Print training/validation statistics\n",
+    "    print(\n",
+    "        \"Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}\".format(\n",
+    "            epoch, train_loss, valid_loss\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "    # Save model if validation loss has decreased, and reset the early-stopping counter\n",
+    "    if valid_loss <= valid_loss_min:\n",
+    "        print(\n",
+    "            \"Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...\".format(\n",
+    "                valid_loss_min, valid_loss\n",
+    "            )\n",
+    "        )\n",
+    "        torch.save(model2.state_dict(), \"model_cifar_exo1.pt\")\n",
+    "        valid_loss_min = valid_loss\n",
+    "        patience_counter = 0\n",
+    "    else:\n",
+    "        patience_counter += 1\n",
+    "\n",
+    "    # Stop once the validation loss has not improved for `patience` epochs in a row;\n",
+    "    # `epoch` is 0-based, so epoch + 1 epochs have been completed at this point\n",
+    "    if patience_counter >= patience:\n",
+    "        print(f\"Early stopping after {epoch + 1} epochs.\")\n",
+    "        break"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d0e23d94",
+   "metadata": {},
+   "source": [
+    "## Checking performance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "id": "206bc2a1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Training vs. validation loss of the new model, one point per completed epoch\n",
+    "loss_curves = {\"train_loss\": train_loss_list, \"validation_loss\": valid_loss_list}\n",
+    "for curve_label, curve_values in loss_curves.items():\n",
+    "    plt.plot(range(len(curve_values)), curve_values, label=curve_label)\n",
+    "\n",
+    "plt.title(\"Performance of the New Model \")\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Loss\")\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "b0fbfa80",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Test Loss: 18.870971\n",
+      "\n",
+      "Test Accuracy of airplane: 75% (758/1000)\n",
+      "Test Accuracy of automobile: 85% (850/1000)\n",
+      "Test Accuracy of  bird: 62% (625/1000)\n",
+      "Test Accuracy of   cat: 39% (394/1000)\n",
+      "Test Accuracy of  deer: 62% (628/1000)\n",
+      "Test Accuracy of   dog: 59% (592/1000)\n",
+      "Test Accuracy of  frog: 81% (812/1000)\n",
+      "Test Accuracy of horse: 71% (714/1000)\n",
+      "Test Accuracy of  ship: 81% (813/1000)\n",
+      "Test Accuracy of truck: 74% (744/1000)\n",
+      "\n",
+      "Test Accuracy (Overall): 69% (6930/10000)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Evaluate the best checkpoint saved during training\n",
+    "model2.load_state_dict(torch.load(\"./model_cifar_exo1.pt\"))\n",
+    "\n",
+    "# track test loss\n",
+    "test_loss = 0.0\n",
+    "class_correct2 = list(0.0 for i in range(10))\n",
+    "class_total2 = list(0.0 for i in range(10))\n",
+    "\n",
+    "model2.eval()\n",
+    "# iterate over test data; gradients are not needed for evaluation\n",
+    "with torch.no_grad():\n",
+    "    for data, target in test_loader:\n",
+    "        # move tensors to GPU if CUDA is available\n",
+    "        if train_on_gpu:\n",
+    "            data, target = data.cuda(), target.cuda()\n",
+    "        # forward pass: compute predicted outputs by passing inputs to the model\n",
+    "        output = model2(data)\n",
+    "        # calculate the batch loss\n",
+    "        loss = criterion(output, target)\n",
+    "        # update test loss\n",
+    "        test_loss += loss.item() * data.size(0)\n",
+    "        # predicted class = index of the largest output score\n",
+    "        _, pred = torch.max(output, 1)\n",
+    "        # compare predictions to true label\n",
+    "        correct_tensor = pred.eq(target.data.view_as(pred))\n",
+    "        # .cpu() is a no-op on CPU tensors; no squeeze, so a batch of size 1 stays indexable\n",
+    "        correct = correct_tensor.cpu().numpy()\n",
+    "        # calculate test accuracy for each object class, looping over the actual batch\n",
+    "        # length because the last batch may hold fewer than batch_size samples\n",
+    "        for i in range(target.size(0)):\n",
+    "            label = target.data[i]\n",
+    "            class_correct2[label] += correct[i].item()\n",
+    "            class_total2[label] += 1\n",
+    "\n",
+    "# average test loss\n",
+    "test_loss = test_loss / len(test_loader)\n",
+    "print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
+    "\n",
+    "for i in range(10):\n",
+    "    if class_total2[i] > 0:\n",
+    "        print(\n",
+    "            \"Test Accuracy of %5s: %2d%% (%2d/%2d)\"\n",
+    "            % (\n",
+    "                classes[i],\n",
+    "                100 * class_correct2[i] / class_total2[i],\n",
+    "                np.sum(class_correct2[i]),\n",
+    "                np.sum(class_total2[i]),\n",
+    "            )\n",
+    "        )\n",
+    "    else:\n",
+    "        print(\"Test Accuracy of %5s: N/A (no training examples)\" % (classes[i]))\n",
+    "\n",
+    "print(\n",
+    "    \"\\nTest Accuracy (Overall): %2d%% (%2d/%2d)\"\n",
+    "    % (\n",
+    "        100.0 * np.sum(class_correct2) / np.sum(class_total2),\n",
+    "        np.sum(class_correct2),\n",
+    "        np.sum(class_total2),\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "efb5df69",
+   "metadata": {},
+   "source": [
+    "We compare the results obtained with this new network to those obtained previously. The new model is better than the first model on almost every class (only *horse* is slightly lower: 71% vs. 72%). Overall test accuracy rises from 63% (6326/10000) with the first model to 69% (6930/10000) with the new one."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "86f7cf40",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Per-class accuracy (correct / total) of the first model and of the new one\n",
+    "accuracy_initial = [hits / seen for hits, seen in zip(class_correct, class_total)]\n",
+    "accuracy_new = [hits / seen for hits, seen in zip(class_correct2, class_total2)]\n",
+    "\n",
+    "# Two bars per class, shifted by half a bar width on either side of the tick\n",
+    "width = .35\n",
+    "x = np.arange(len(classes))\n",
+    "\n",
+    "fig, ax = plt.subplots()\n",
+    "r1 = ax.bar(x - width/2, accuracy_initial, width, label=\"Initial Model\")\n",
+    "r2 = ax.bar(x + width/2, accuracy_new, width, label=\"New Model\")\n",
+    "\n",
+    "ax.set(xlabel='Classes', ylabel='Accuracy', title='Accuracy comparison by class')\n",
+    "ax.set_xticks(x)\n",
+    "ax.set_xticklabels(classes)\n",
+    "ax.legend()\n",
+    "plt.show()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "bc381cf4",
@@ -883,7 +1440,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "bbd48800",
+   "id": "0316c65c",
    "metadata": {},
    "source": [
     "Experiments:\n",
@@ -926,7 +1483,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.5 ('base')",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -940,7 +1497,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.11.5"
   },
   "vscode": {
    "interpreter": {