From 4d79ee17857e9be3f219b9529c8639a7b38a7559 Mon Sep 17 00:00:00 2001
From: pmarin72 <75830392+pmarin72@users.noreply.github.com>
Date: Thu, 30 Nov 2023 18:04:53 +0100
Subject: [PATCH] some improvements on ex 2

---
 TD2 Deep Learning.ipynb | 329 +++++++++++++++++++++-------------------
 1 file changed, 176 insertions(+), 153 deletions(-)

diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb
index d4132b0..ad0ccd3 100644
--- a/TD2 Deep Learning.ipynb	
+++ b/TD2 Deep Learning.ipynb	
@@ -33,7 +33,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "330a42f5",
    "metadata": {},
    "outputs": [
@@ -55,7 +55,7 @@
       "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
       "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.4)\n",
       "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.16)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (2023.7.22)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from requests->torchvision) (2023.11.17)\n",
       "Requirement already satisfied: mpmath>=0.19 in c:\\users\\marin\\anaconda3\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
       "Note: you may need to restart the kernel to use updated packages.\n"
      ]
@@ -76,7 +76,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "id": "b1950f0a",
    "metadata": {},
    "outputs": [
@@ -84,34 +84,34 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor([[ 0.5435, -0.4544, -0.1205,  0.4673,  0.1775, -0.7653, -0.0387,  0.0637,\n",
-      "         -0.1312, -1.8514],\n",
-      "        [-0.7451, -0.0433,  0.5996, -0.1318,  0.7856,  0.5377, -1.0717,  0.4718,\n",
-      "         -0.4624,  1.1403],\n",
-      "        [ 1.4221,  0.4673, -0.1187,  0.9117, -1.7922, -0.2721,  0.0516,  1.3558,\n",
-      "         -0.8833, -1.0339],\n",
-      "        [-1.4286, -1.3892, -1.4825,  1.7989,  0.1235,  1.3108, -1.5860,  0.6306,\n",
-      "          0.3286, -0.5756],\n",
-      "        [ 0.2860, -0.4381, -0.8007, -0.9251, -2.4581,  0.6307, -0.9971, -1.0066,\n",
-      "          0.8453, -0.6403],\n",
-      "        [-1.7802, -0.5362,  0.5685,  0.0599,  0.1256,  0.2542, -0.4363, -0.9823,\n",
-      "          0.4746,  1.6888],\n",
-      "        [-1.6597,  1.0951,  0.9582, -1.5032,  1.1591,  0.8159, -1.4805, -0.5566,\n",
-      "          0.4475, -1.6350],\n",
-      "        [-0.4390, -0.5932,  0.6092,  1.7203,  0.4294, -2.0137,  1.5183, -0.0681,\n",
-      "          2.6924,  2.2244],\n",
-      "        [-0.7404,  1.5136,  0.6477, -0.4592,  1.6904, -0.4243,  1.2477, -0.7878,\n",
-      "          0.4548,  0.4966],\n",
-      "        [-0.5082,  0.0487,  0.4923, -0.0613, -1.0030,  0.3108, -0.7571, -0.3653,\n",
-      "          0.4734,  0.3244],\n",
-      "        [-1.5140,  0.9956, -0.5122,  0.2580, -0.4591, -0.2065,  1.3851,  0.2364,\n",
-      "          0.5900, -0.7037],\n",
-      "        [ 1.0271, -1.3211,  0.0545,  0.5302, -0.2711,  0.6698,  2.2225,  0.2634,\n",
-      "         -0.2574, -1.6689],\n",
-      "        [ 0.3319, -1.2073,  0.3785,  1.5544,  0.4043, -1.0159,  0.1956, -1.7744,\n",
-      "          0.3340,  0.8643],\n",
-      "        [ 2.8115, -0.5446,  0.5140, -0.3576, -1.2501, -0.2065, -0.3383,  0.2077,\n",
-      "         -0.2065, -1.2150]])\n",
+      "tensor([[-0.3740, -0.1337, -1.1278,  1.0594, -0.2462, -1.1751, -0.1005,  1.1031,\n",
+      "          0.5354, -0.1985],\n",
+      "        [ 0.3067, -1.0501, -1.3315, -2.7529, -1.8386, -1.0362, -0.8983,  0.4816,\n",
+      "         -0.7046,  0.2330],\n",
+      "        [-0.1206,  0.8951, -0.6436, -0.3075,  0.9056, -0.6875, -0.7694, -0.2017,\n",
+      "          0.9787, -0.5610],\n",
+      "        [-2.8668, -0.3878,  0.7541, -1.1662,  0.4237,  1.1266, -0.3558,  0.1105,\n",
+      "         -1.0558, -1.8606],\n",
+      "        [ 0.3507, -0.8552, -0.9354,  0.3753,  1.2805, -0.3248, -0.4088, -0.5620,\n",
+      "          0.1417,  1.0160],\n",
+      "        [-0.7317, -3.4209, -0.4999, -0.1847,  0.1923,  0.7617,  0.2245, -1.9357,\n",
+      "          0.0595,  2.1604],\n",
+      "        [ 0.1924, -1.1935,  0.9019,  1.2187, -1.7188, -0.7759, -1.3686,  0.2335,\n",
+      "          0.3900,  0.9486],\n",
+      "        [ 1.4248,  0.9080,  0.3575,  1.9698, -0.3119,  1.1467,  1.9559,  1.9424,\n",
+      "         -0.1275, -0.0842],\n",
+      "        [ 0.9739,  1.7380, -0.3301, -0.3293,  0.0384, -0.9268, -1.0350, -0.6020,\n",
+      "         -1.3752,  0.6666],\n",
+      "        [-0.3124, -0.3678, -1.8143, -0.0260, -0.6726, -0.6671, -0.1143,  0.5844,\n",
+      "         -0.8527,  0.7353],\n",
+      "        [ 1.0313,  0.3691,  0.8323, -0.4683, -1.4537, -0.5249,  2.0043,  0.0210,\n",
+      "          0.7745,  1.2210],\n",
+      "        [ 0.7038, -0.3010,  1.8068,  1.0899,  1.9105,  0.3594,  0.7311,  0.8623,\n",
+      "          0.5980, -1.0860],\n",
+      "        [ 1.0422,  1.6860,  0.1746, -0.7042,  0.9685,  1.8207,  0.5156, -0.8631,\n",
+      "          0.8923,  0.5413],\n",
+      "        [-1.7911,  0.4318, -0.6459, -1.6303, -1.9783,  0.9335, -0.2233,  0.9090,\n",
+      "         -1.0225, -0.0549]])\n",
       "AlexNet(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
@@ -181,7 +181,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "id": "6e18f2fd",
    "metadata": {},
    "outputs": [
@@ -215,7 +215,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "id": "462666a2",
    "metadata": {},
    "outputs": [
@@ -296,7 +296,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "317bf070",
    "metadata": {},
    "outputs": [
@@ -358,9 +358,17 @@
     "Loss function and training using SGD (Stochastic Gradient Descent) optimizer"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "083fc25a",
+   "metadata": {},
+   "source": [
+    "We add a patience counter to stop the training early if overfitting occurs: if the validation loss does not decrease for 3 consecutive epochs, we stop the training."
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 8,
    "id": "4b53f229",
    "metadata": {},
    "outputs": [
@@ -368,33 +376,34 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Epoch: 0 \tTraining Loss: 43.971403 \tValidation Loss: 38.781364\n",
-      "Validation loss decreased (inf --> 38.781364).  Saving model ...\n",
-      "Epoch: 1 \tTraining Loss: 34.860234 \tValidation Loss: 32.887555\n",
-      "Validation loss decreased (38.781364 --> 32.887555).  Saving model ...\n",
-      "Epoch: 2 \tTraining Loss: 30.713754 \tValidation Loss: 30.105906\n",
-      "Validation loss decreased (32.887555 --> 30.105906).  Saving model ...\n",
-      "Epoch: 3 \tTraining Loss: 28.186929 \tValidation Loss: 28.076022\n",
-      "Validation loss decreased (30.105906 --> 28.076022).  Saving model ...\n",
-      "Epoch: 4 \tTraining Loss: 26.351580 \tValidation Loss: 26.367747\n",
-      "Validation loss decreased (28.076022 --> 26.367747).  Saving model ...\n",
-      "Epoch: 5 \tTraining Loss: 24.947239 \tValidation Loss: 26.368494\n",
-      "Epoch: 6 \tTraining Loss: 23.778135 \tValidation Loss: 24.579198\n",
-      "Validation loss decreased (26.367747 --> 24.579198).  Saving model ...\n",
-      "Epoch: 7 \tTraining Loss: 22.707515 \tValidation Loss: 24.203169\n",
-      "Validation loss decreased (24.579198 --> 24.203169).  Saving model ...\n",
-      "Epoch: 8 \tTraining Loss: 21.805590 \tValidation Loss: 23.090124\n",
-      "Validation loss decreased (24.203169 --> 23.090124).  Saving model ...\n",
-      "Epoch: 9 \tTraining Loss: 21.043298 \tValidation Loss: 22.905686\n",
-      "Validation loss decreased (23.090124 --> 22.905686).  Saving model ...\n",
-      "Epoch: 10 \tTraining Loss: 20.250996 \tValidation Loss: 23.170775\n",
-      "Epoch: 11 \tTraining Loss: 19.621161 \tValidation Loss: 22.586260\n",
-      "Validation loss decreased (22.905686 --> 22.586260).  Saving model ...\n",
-      "Epoch: 12 \tTraining Loss: 18.967947 \tValidation Loss: 22.084914\n",
-      "Validation loss decreased (22.586260 --> 22.084914).  Saving model ...\n",
-      "Epoch: 13 \tTraining Loss: 18.294242 \tValidation Loss: 22.170993\n",
-      "Epoch: 14 \tTraining Loss: 17.742965 \tValidation Loss: 22.235334\n",
-      "Epoch: 15 \tTraining Loss: 17.191995 \tValidation Loss: 22.188067\n",
+      "Epoch: 0 \tTraining Loss: 41.928474 \tValidation Loss: 36.149142\n",
+      "Validation loss decreased (inf --> 36.149142).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 33.406021 \tValidation Loss: 31.535990\n",
+      "Validation loss decreased (36.149142 --> 31.535990).  Saving model ...\n",
+      "Epoch: 2 \tTraining Loss: 29.986645 \tValidation Loss: 29.026595\n",
+      "Validation loss decreased (31.535990 --> 29.026595).  Saving model ...\n",
+      "Epoch: 3 \tTraining Loss: 27.796806 \tValidation Loss: 28.266311\n",
+      "Validation loss decreased (29.026595 --> 28.266311).  Saving model ...\n",
+      "Epoch: 4 \tTraining Loss: 26.246026 \tValidation Loss: 26.360779\n",
+      "Validation loss decreased (28.266311 --> 26.360779).  Saving model ...\n",
+      "Epoch: 5 \tTraining Loss: 24.983256 \tValidation Loss: 25.554680\n",
+      "Validation loss decreased (26.360779 --> 25.554680).  Saving model ...\n",
+      "Epoch: 6 \tTraining Loss: 23.907650 \tValidation Loss: 24.931439\n",
+      "Validation loss decreased (25.554680 --> 24.931439).  Saving model ...\n",
+      "Epoch: 7 \tTraining Loss: 22.924509 \tValidation Loss: 24.198110\n",
+      "Validation loss decreased (24.931439 --> 24.198110).  Saving model ...\n",
+      "Epoch: 8 \tTraining Loss: 22.127981 \tValidation Loss: 23.928127\n",
+      "Validation loss decreased (24.198110 --> 23.928127).  Saving model ...\n",
+      "Epoch: 9 \tTraining Loss: 21.372160 \tValidation Loss: 23.599568\n",
+      "Validation loss decreased (23.928127 --> 23.599568).  Saving model ...\n",
+      "Epoch: 10 \tTraining Loss: 20.641422 \tValidation Loss: 23.933316\n",
+      "Epoch: 11 \tTraining Loss: 19.919589 \tValidation Loss: 22.942826\n",
+      "Validation loss decreased (23.599568 --> 22.942826).  Saving model ...\n",
+      "Epoch: 12 \tTraining Loss: 19.266797 \tValidation Loss: 22.437601\n",
+      "Validation loss decreased (22.942826 --> 22.437601).  Saving model ...\n",
+      "Epoch: 13 \tTraining Loss: 18.638509 \tValidation Loss: 22.854837\n",
+      "Epoch: 14 \tTraining Loss: 18.039349 \tValidation Loss: 23.089719\n",
+      "Epoch: 15 \tTraining Loss: 17.516967 \tValidation Loss: 22.803356\n",
       "Early stopping after 15 epochss.\n"
      ]
     }
@@ -496,15 +505,14 @@
    "metadata": {},
    "outputs": [
     {
-     "ename": "NameError",
-     "evalue": "name 'train_loss_list' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "Cell \u001b[1;32mIn[9], line 3\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m----> 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(train_loss_list)), train_loss_list, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      4\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(valid_loss_list)), valid_loss_list, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      5\u001b[0m plt\u001b[38;5;241m.\u001b[39mxlabel(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
-      "\u001b[1;31mNameError\u001b[0m: name 'train_loss_list' is not defined"
-     ]
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
     }
    ],
    "source": [
@@ -529,7 +537,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
+   "execution_count": 11,
    "id": "e93efdfc",
    "metadata": {},
    "outputs": [
@@ -537,20 +545,20 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Test Loss: 21.919417\n",
+      "Test Loss: 22.254374\n",
       "\n",
       "Test Accuracy of airplane: 67% (679/1000)\n",
-      "Test Accuracy of automobile: 78% (782/1000)\n",
-      "Test Accuracy of  bird: 59% (596/1000)\n",
-      "Test Accuracy of   cat: 39% (398/1000)\n",
-      "Test Accuracy of  deer: 55% (554/1000)\n",
-      "Test Accuracy of   dog: 48% (482/1000)\n",
-      "Test Accuracy of  frog: 67% (678/1000)\n",
-      "Test Accuracy of horse: 60% (605/1000)\n",
-      "Test Accuracy of  ship: 71% (713/1000)\n",
-      "Test Accuracy of truck: 68% (687/1000)\n",
+      "Test Accuracy of automobile: 75% (753/1000)\n",
+      "Test Accuracy of  bird: 37% (377/1000)\n",
+      "Test Accuracy of   cat: 35% (356/1000)\n",
+      "Test Accuracy of  deer: 63% (638/1000)\n",
+      "Test Accuracy of   dog: 46% (465/1000)\n",
+      "Test Accuracy of  frog: 73% (732/1000)\n",
+      "Test Accuracy of horse: 69% (691/1000)\n",
+      "Test Accuracy of  ship: 73% (734/1000)\n",
+      "Test Accuracy of truck: 66% (663/1000)\n",
       "\n",
-      "Test Accuracy (Overall): 61% (6174/10000)\n"
+      "Test Accuracy (Overall): 60% (6088/10000)\n"
      ]
     }
    ],
@@ -638,7 +646,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 13,
    "id": "43fff7d9",
    "metadata": {},
    "outputs": [
@@ -686,11 +694,11 @@
     "\n",
     "\n",
     "# create a complete CNN\n",
-    "model2 = Net()\n",
-    "print(model2)\n",
+    "model = Net()\n",
+    "print(model)\n",
     "# move tensors to GPU if CUDA is available\n",
     "if train_on_gpu:\n",
-    "    model2.cuda()"
+    "    model.cuda()"
    ]
   },
   {
@@ -703,7 +711,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 14,
    "id": "40638ce8",
    "metadata": {},
    "outputs": [
@@ -711,50 +719,54 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Epoch: 0 \tTraining Loss: 45.992308 \tValidation Loss: 45.734079\n",
-      "Validation loss decreased (inf --> 45.734079).  Saving model ...\n",
-      "Epoch: 1 \tTraining Loss: 42.861547 \tValidation Loss: 39.263893\n",
-      "Validation loss decreased (45.734079 --> 39.263893).  Saving model ...\n",
-      "Epoch: 2 \tTraining Loss: 36.771661 \tValidation Loss: 34.267659\n",
-      "Validation loss decreased (39.263893 --> 34.267659).  Saving model ...\n",
-      "Epoch: 3 \tTraining Loss: 33.214925 \tValidation Loss: 32.355083\n",
-      "Validation loss decreased (34.267659 --> 32.355083).  Saving model ...\n",
-      "Epoch: 4 \tTraining Loss: 31.095436 \tValidation Loss: 30.169813\n",
-      "Validation loss decreased (32.355083 --> 30.169813).  Saving model ...\n",
-      "Epoch: 5 \tTraining Loss: 29.427925 \tValidation Loss: 28.723804\n",
-      "Validation loss decreased (30.169813 --> 28.723804).  Saving model ...\n",
-      "Epoch: 6 \tTraining Loss: 27.979727 \tValidation Loss: 27.368319\n",
-      "Validation loss decreased (28.723804 --> 27.368319).  Saving model ...\n",
-      "Epoch: 7 \tTraining Loss: 26.560045 \tValidation Loss: 26.531596\n",
-      "Validation loss decreased (27.368319 --> 26.531596).  Saving model ...\n",
-      "Epoch: 8 \tTraining Loss: 25.233062 \tValidation Loss: 25.151196\n",
-      "Validation loss decreased (26.531596 --> 25.151196).  Saving model ...\n",
-      "Epoch: 9 \tTraining Loss: 23.933243 \tValidation Loss: 24.118065\n",
-      "Validation loss decreased (25.151196 --> 24.118065).  Saving model ...\n",
-      "Epoch: 10 \tTraining Loss: 22.824302 \tValidation Loss: 23.136555\n",
-      "Validation loss decreased (24.118065 --> 23.136555).  Saving model ...\n",
-      "Epoch: 11 \tTraining Loss: 21.638535 \tValidation Loss: 22.276559\n",
-      "Validation loss decreased (23.136555 --> 22.276559).  Saving model ...\n",
-      "Epoch: 12 \tTraining Loss: 20.643672 \tValidation Loss: 21.490677\n",
-      "Validation loss decreased (22.276559 --> 21.490677).  Saving model ...\n",
-      "Epoch: 13 \tTraining Loss: 19.723929 \tValidation Loss: 20.878862\n",
-      "Validation loss decreased (21.490677 --> 20.878862).  Saving model ...\n",
-      "Epoch: 14 \tTraining Loss: 18.881656 \tValidation Loss: 20.151909\n",
-      "Validation loss decreased (20.878862 --> 20.151909).  Saving model ...\n",
-      "Epoch: 15 \tTraining Loss: 18.077398 \tValidation Loss: 22.761932\n",
-      "Epoch: 16 \tTraining Loss: 17.244630 \tValidation Loss: 20.172645\n",
-      "Epoch: 17 \tTraining Loss: 16.708238 \tValidation Loss: 19.282629\n",
-      "Validation loss decreased (20.151909 --> 19.282629).  Saving model ...\n",
-      "Epoch: 18 \tTraining Loss: 16.049521 \tValidation Loss: 19.141060\n",
-      "Validation loss decreased (19.282629 --> 19.141060).  Saving model ...\n",
-      "Epoch: 19 \tTraining Loss: 15.306451 \tValidation Loss: 18.852022\n",
-      "Validation loss decreased (19.141060 --> 18.852022).  Saving model ...\n",
-      "Epoch: 20 \tTraining Loss: 14.730358 \tValidation Loss: 18.173483\n",
-      "Validation loss decreased (18.852022 --> 18.173483).  Saving model ...\n",
-      "Epoch: 21 \tTraining Loss: 14.233885 \tValidation Loss: 18.950112\n",
-      "Epoch: 22 \tTraining Loss: 13.598710 \tValidation Loss: 19.136868\n",
-      "Epoch: 23 \tTraining Loss: 13.122754 \tValidation Loss: 18.415266\n",
-      "Early stopping after 23 epochss.\n"
+      "Epoch: 0 \tTraining Loss: 45.923728 \tValidation Loss: 45.169865\n",
+      "Validation loss decreased (inf --> 45.169865).  Saving model ...\n",
+      "Epoch: 1 \tTraining Loss: 41.546925 \tValidation Loss: 38.793506\n",
+      "Validation loss decreased (45.169865 --> 38.793506).  Saving model ...\n",
+      "Epoch: 2 \tTraining Loss: 36.476450 \tValidation Loss: 35.295533\n",
+      "Validation loss decreased (38.793506 --> 35.295533).  Saving model ...\n",
+      "Epoch: 3 \tTraining Loss: 33.383830 \tValidation Loss: 33.026579\n",
+      "Validation loss decreased (35.295533 --> 33.026579).  Saving model ...\n",
+      "Epoch: 4 \tTraining Loss: 31.333926 \tValidation Loss: 30.475611\n",
+      "Validation loss decreased (33.026579 --> 30.475611).  Saving model ...\n",
+      "Epoch: 5 \tTraining Loss: 29.614357 \tValidation Loss: 29.033134\n",
+      "Validation loss decreased (30.475611 --> 29.033134).  Saving model ...\n",
+      "Epoch: 6 \tTraining Loss: 28.001916 \tValidation Loss: 28.396793\n",
+      "Validation loss decreased (29.033134 --> 28.396793).  Saving model ...\n",
+      "Epoch: 7 \tTraining Loss: 26.720634 \tValidation Loss: 26.933168\n",
+      "Validation loss decreased (28.396793 --> 26.933168).  Saving model ...\n",
+      "Epoch: 8 \tTraining Loss: 25.416131 \tValidation Loss: 25.584761\n",
+      "Validation loss decreased (26.933168 --> 25.584761).  Saving model ...\n",
+      "Epoch: 9 \tTraining Loss: 24.044020 \tValidation Loss: 25.319066\n",
+      "Validation loss decreased (25.584761 --> 25.319066).  Saving model ...\n",
+      "Epoch: 10 \tTraining Loss: 22.966510 \tValidation Loss: 24.229066\n",
+      "Validation loss decreased (25.319066 --> 24.229066).  Saving model ...\n",
+      "Epoch: 11 \tTraining Loss: 21.892974 \tValidation Loss: 23.542444\n",
+      "Validation loss decreased (24.229066 --> 23.542444).  Saving model ...\n",
+      "Epoch: 12 \tTraining Loss: 21.116193 \tValidation Loss: 22.082797\n",
+      "Validation loss decreased (23.542444 --> 22.082797).  Saving model ...\n",
+      "Epoch: 13 \tTraining Loss: 20.105640 \tValidation Loss: 21.183286\n",
+      "Validation loss decreased (22.082797 --> 21.183286).  Saving model ...\n",
+      "Epoch: 14 \tTraining Loss: 19.280560 \tValidation Loss: 21.298658\n",
+      "Epoch: 15 \tTraining Loss: 18.506203 \tValidation Loss: 20.982773\n",
+      "Validation loss decreased (21.183286 --> 20.982773).  Saving model ...\n",
+      "Epoch: 16 \tTraining Loss: 17.824514 \tValidation Loss: 20.400616\n",
+      "Validation loss decreased (20.982773 --> 20.400616).  Saving model ...\n",
+      "Epoch: 17 \tTraining Loss: 17.116147 \tValidation Loss: 19.853483\n",
+      "Validation loss decreased (20.400616 --> 19.853483).  Saving model ...\n",
+      "Epoch: 18 \tTraining Loss: 16.443002 \tValidation Loss: 20.076364\n",
+      "Epoch: 19 \tTraining Loss: 15.684523 \tValidation Loss: 19.175958\n",
+      "Validation loss decreased (19.853483 --> 19.175958).  Saving model ...\n",
+      "Epoch: 20 \tTraining Loss: 15.149510 \tValidation Loss: 19.180701\n",
+      "Epoch: 21 \tTraining Loss: 14.661278 \tValidation Loss: 18.647935\n",
+      "Validation loss decreased (19.175958 --> 18.647935).  Saving model ...\n",
+      "Epoch: 22 \tTraining Loss: 14.030736 \tValidation Loss: 18.935119\n",
+      "Epoch: 23 \tTraining Loss: 13.360740 \tValidation Loss: 18.588247\n",
+      "Validation loss decreased (18.647935 --> 18.588247).  Saving model ...\n",
+      "Epoch: 24 \tTraining Loss: 12.932477 \tValidation Loss: 19.555995\n",
+      "Epoch: 25 \tTraining Loss: 12.393839 \tValidation Loss: 19.220007\n",
+      "Epoch: 26 \tTraining Loss: 11.902525 \tValidation Loss: 19.395309\n",
+      "Early stopping after 26 epochss.\n"
      ]
     }
    ],
@@ -762,7 +774,7 @@
     "import torch.optim as optim\n",
     "\n",
     "criterion = nn.CrossEntropyLoss()  # specify loss function\n",
-    "optimizer = optim.SGD(model2.parameters(), lr=0.01)  # specify optimizer\n",
+    "optimizer = optim.SGD(model.parameters(), lr=0.01)  # specify optimizer\n",
     "\n",
     "n_epochs = 30  # number of epochs to train the model\n",
     "train_loss_list = []  # list to store loss to visualize\n",
@@ -777,7 +789,7 @@
     "    valid_loss = 0.0\n",
     "\n",
     "    # Train the model\n",
-    "    model2.train()\n",
+    "    model.train()\n",
     "    for data, target in train_loader:\n",
     "        # Move tensors to GPU if CUDA is available\n",
     "        if train_on_gpu:\n",
@@ -785,7 +797,7 @@
     "        # Clear the gradients of all optimized variables\n",
     "        optimizer.zero_grad()\n",
     "        # Forward pass: compute predicted outputs by passing inputs to the model\n",
-    "        output = model2(data)\n",
+    "        output = model(data)\n",
     "        # Calculate the batch loss\n",
     "        loss = criterion(output, target)\n",
     "        # Backward pass: compute gradient of the loss with respect to model parameters\n",
@@ -796,13 +808,13 @@
     "        train_loss += loss.item() * data.size(0)\n",
     "\n",
     "    # Validate the model\n",
-    "    model2.eval()\n",
+    "    model.eval()\n",
     "    for data, target in valid_loader:\n",
     "        # Move tensors to GPU if CUDA is available\n",
     "        if train_on_gpu:\n",
     "            data, target = data.cuda(), target.cuda()\n",
     "        # Forward pass: compute predicted outputs by passing inputs to the model\n",
-    "        output = model2(data)\n",
+    "        output = model(data)\n",
     "        # Calculate the batch loss\n",
     "        loss = criterion(output, target)\n",
     "        # Update average validation loss\n",
@@ -829,7 +841,7 @@
     "                valid_loss_min, valid_loss\n",
     "            )\n",
     "        )\n",
-    "        torch.save(model2.state_dict(), \"model_cifar_exo1.pt\")\n",
+    "        torch.save(model.state_dict(), \"model_cifar_exo1.pt\")\n",
     "        valid_loss_min = valid_loss\n",
     "        patience_counter = 0\n",
     "    else:\n",
@@ -850,13 +862,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 15,
    "id": "206bc2a1",
    "metadata": {},
    "outputs": [
     {
      "data": {
-      "image/png": "",
+      "image/png": "",
       "text/plain": [
        "<Figure size 640x480 with 1 Axes>"
       ]
@@ -877,38 +889,47 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 16,
    "id": "b0fbfa80",
    "metadata": {},
    "outputs": [
     {
-     "ename": "NameError",
-     "evalue": "name 'model2' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "Cell \u001b[1;32mIn[10], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m model2\u001b[38;5;241m.\u001b[39mload_state_dict(torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./model_cifar_exo1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m      3\u001b[0m \u001b[38;5;66;03m# track test loss\u001b[39;00m\n\u001b[0;32m      4\u001b[0m test_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n",
-      "\u001b[1;31mNameError\u001b[0m: name 'model2' is not defined"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Test Loss: 19.006048\n",
+      "\n",
+      "Test Accuracy of airplane: 72% (729/1000)\n",
+      "Test Accuracy of automobile: 79% (791/1000)\n",
+      "Test Accuracy of  bird: 60% (601/1000)\n",
+      "Test Accuracy of   cat: 48% (484/1000)\n",
+      "Test Accuracy of  deer: 64% (641/1000)\n",
+      "Test Accuracy of   dog: 61% (610/1000)\n",
+      "Test Accuracy of  frog: 72% (727/1000)\n",
+      "Test Accuracy of horse: 71% (716/1000)\n",
+      "Test Accuracy of  ship: 87% (875/1000)\n",
+      "Test Accuracy of truck: 77% (775/1000)\n",
+      "\n",
+      "Test Accuracy (Overall): 69% (6949/10000)\n"
      ]
     }
    ],
    "source": [
-    "model2.load_state_dict(torch.load(\"./model_cifar_exo1.pt\"))\n",
+    "model.load_state_dict(torch.load(\"./model_cifar_exo1.pt\"))\n",
     "\n",
     "# track test loss\n",
     "test_loss = 0.0\n",
     "class_correct2 = list(0.0 for i in range(10))\n",
     "class_total2 = list(0.0 for i in range(10))\n",
     "\n",
-    "model2.eval()\n",
+    "model.eval()\n",
     "# iterate over test data\n",
     "for data, target in test_loader:\n",
     "    # move tensors to GPU if CUDA is available\n",
     "    if train_on_gpu:\n",
     "        data, target = data.cuda(), target.cuda()\n",
     "    # forward pass: compute predicted outputs by passing inputs to the model\n",
-    "    output = model2(data)\n",
+    "    output = model(data)\n",
     "    # calculate the batch loss\n",
     "    loss = criterion(output, target)\n",
     "    # update test loss\n",
@@ -1043,6 +1064,8 @@
    "source": [
     "import os\n",
     "\n",
+    "model = Net()\n",
+    "model.load_state_dict(torch.load(\"./model_cifar_exo1.pt\"))\n",
     "\n",
     "def print_size_of_model(model, label=\"\"):\n",
     "    torch.save(model.state_dict(), \"temp.p\")\n",
@@ -1221,7 +1244,7 @@
     "\n",
     "fig, ax = plt.subplots()\n",
     "r1 = ax.bar(x - width/2, accuracy_initial, width, label=\"Initial Model\")\n",
-    "r2 = ax.bar(x + width/2, accuracy_quant, width, label=\"New Model\")\n",
+    "r2 = ax.bar(x + width/2, accuracy_quant, width, label=\"Quantized Model\")\n",
     "\n",
     "ax.set_xlabel('Classes')\n",
     "ax.set_ylabel('Accuracy')\n",
-- 
GitLab