From c8f84face87c06c6520292591b6ac8fb7278772c Mon Sep 17 00:00:00 2001
From: HU <franck.hu@ecl20.ec-lyon.fr>
Date: Fri, 1 Dec 2023 17:22:37 +0100
Subject: [PATCH] maj

---
 TD2 Deep Learning.ipynb | 102 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 99 insertions(+), 3 deletions(-)

diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb
index e581032..ec4ce7d 100644
--- a/TD2 Deep Learning.ipynb	
+++ b/TD2 Deep Learning.ipynb	
@@ -1871,6 +1871,18 @@
     "Apply ther quantization (post and quantization aware) and evaluate impact on model size and accuracy."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "bb08304c",
+   "metadata": {},
+   "source": [
+    "_________\n",
+    "\n",
+    "Dans cette partie, nous allons ajouter un dataset de test. On configure le data_transform pour qu'il prennent en compte les données test contenu dans le dossier test.\n",
+    "\n",
+    "On va aussi configurer la sortie avec la nouvelle structure de neurones."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 42,
@@ -1988,6 +2000,8 @@
     "\n",
     "# Data augmentation and normalization for training\n",
     "# Just normalization for validation\n",
+    "\n",
+    "# on ajoute les data test\n",
     "data_transforms = {\n",
     "    \"train\": transforms.Compose(\n",
     "        [\n",
@@ -2150,6 +2164,8 @@
     "# Replace the final fully connected layer\n",
     "# Parameters of newly constructed modules have requires_grad=True by default\n",
     "num_ftrs = model.fc.in_features\n",
+    "\n",
+    "# on ajoute les nouvelles couches\n",
     "model.fc = nn.Sequential(\n",
     "    nn.Linear(num_ftrs, 256),  # Add a layer with 256 neurons\n",
     "    nn.ReLU(),                 # Apply ReLU activation function\n",
@@ -2173,6 +2189,14 @@
     "\n"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "ec105406",
+   "metadata": {},
+   "source": [
+    "On a entraîné ce nouveau modèle. on définit maintenant la fonction eval_model(model), qui va évaluer les performances du modèle sur le data test."
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "8e2bd8fd",
@@ -2211,16 +2235,16 @@
     }
    ],
    "source": [
-    "\n",
     "def eval_model(model):\n",
-    "\n",
+    "    # initialisation des variables\n",
     "    running_loss = 0.0\n",
     "    running_corrects = 0\n",
+    "    \n",
+    "    #on teste sur les données test\n",
     "    for inputs, labels in dataloaders[\"test\"]:\n",
     "        inputs = inputs.to(device)\n",
     "        labels = labels.to(device)\n",
     "\n",
-    "\n",
     "        outputs = model(inputs)\n",
     "        _, preds = torch.max(outputs, 1)\n",
     "        loss = criterion(outputs, labels)\n",
@@ -2229,6 +2253,7 @@
     "        running_loss += loss.item() * inputs.size(0)\n",
     "        running_corrects += torch.sum(preds == labels.data)\n",
     "\n",
+    "    #Calcul du loss et de l'accuracy\n",
     "    loss = running_loss / dataset_sizes[\"test\"]\n",
     "    acc = running_corrects.double() / dataset_sizes[\"test\"]\n",
     "        \n",
@@ -2238,6 +2263,77 @@
     "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "c01bb10c",
+   "metadata": {},
+   "source": [
+    "On obtient une précision de ??????.\n",
+    "\n",
+    "_________\n",
+    "\n",
+    "On va maintenant quantizer le modèle"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "id": "0d8116cf",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "model:  int8  \t Size (KB): 45304.25\n",
+      "model:  int8  \t Size (KB): 44911.014\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n",
+      "  warnings.warn(\n",
+      "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n",
+      "  warnings.warn(\n",
+      "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n",
+      "  warnings.warn(\n",
+      "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loss: 0.2149 Acc: 1.0000\n"
+     ]
+    }
+   ],
+   "source": [
+    "quantized_model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)\n",
+    "print_size_of_model(model, \"int8\")\n",
+    "\n",
+    "print_size_of_model(quantized_model, \"int8\")\n",
+    "\n",
+    "\n",
+    "\n",
+    "loss,acc=eval_model(quantized_model)\n",
+    "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "913c3dae",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss,acc=eval_model(quantized_model)\n",
+    "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "04a263f0",
-- 
GitLab