From c8f84face87c06c6520292591b6ac8fb7278772c Mon Sep 17 00:00:00 2001 From: HU <franck.hu@ecl20.ec-lyon.fr> Date: Fri, 1 Dec 2023 17:22:37 +0100 Subject: [PATCH] maj --- TD2 Deep Learning.ipynb | 102 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb index e581032..ec4ce7d 100644 --- a/TD2 Deep Learning.ipynb +++ b/TD2 Deep Learning.ipynb @@ -1871,6 +1871,18 @@ "Apply ther quantization (post and quantization aware) and evaluate impact on model size and accuracy." ] }, + { + "cell_type": "markdown", + "id": "bb08304c", + "metadata": {}, + "source": [ + "_________\n", + "\n", + "Dans cette partie, nous allons ajouter un dataset de test. On configure le data_transform pour qu'il prennent en compte les données test contenu dans le dossier test.\n", + "\n", + "On va aussi configurer la sortie avec la nouvelle structure de neurones." + ] + }, { "cell_type": "code", "execution_count": 42, @@ -1988,6 +2000,8 @@ "\n", "# Data augmentation and normalization for training\n", "# Just normalization for validation\n", + "\n", + "# on ajoute les data test\n", "data_transforms = {\n", " \"train\": transforms.Compose(\n", " [\n", @@ -2150,6 +2164,8 @@ "# Replace the final fully connected layer\n", "# Parameters of newly constructed modules have requires_grad=True by default\n", "num_ftrs = model.fc.in_features\n", + "\n", + "# on ajoute les nouvelles couches\n", "model.fc = nn.Sequential(\n", " nn.Linear(num_ftrs, 256), # Add a layer with 256 neurons\n", " nn.ReLU(), # Apply ReLU activation function\n", @@ -2173,6 +2189,14 @@ "\n" ] }, + { + "cell_type": "markdown", + "id": "ec105406", + "metadata": {}, + "source": [ + "On a entraîné ce nouveau modèle. on définit maintenant la fonction eval_model(model), qui va évaluer les performances du modèle sur le data test." + ] + }, { "cell_type": "markdown", "id": "8e2bd8fd", @@ -2211,16 +2235,16 @@ } ], "source": [ - "\n", "def eval_model(model):\n", - "\n", + " # initialisation des variables\n", " running_loss = 0.0\n", " running_corrects = 0\n", + " \n", + " #on teste sur les données test\n", " for inputs, labels in dataloaders[\"test\"]:\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", "\n", - "\n", " outputs = model(inputs)\n", " _, preds = torch.max(outputs, 1)\n", " loss = criterion(outputs, labels)\n", @@ -2229,6 +2253,7 @@ " running_loss += loss.item() * inputs.size(0)\n", " running_corrects += torch.sum(preds == labels.data)\n", "\n", + " #Calcul du loss et de l'accuracy\n", " loss = running_loss / dataset_sizes[\"test\"]\n", " acc = running_corrects.double() / dataset_sizes[\"test\"]\n", " \n", @@ -2238,6 +2263,77 @@ "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))" ] }, + { + "cell_type": "markdown", + "id": "c01bb10c", + "metadata": {}, + "source": [ + "On obtient une précision de ??????.\n", + "\n", + "_________\n", + "\n", + "On va maintenant quantizer le modèle" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "0d8116cf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model: int8 \t Size (KB): 45304.25\n", + "model: int8 \t Size (KB): 44911.014\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n", + "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n", + "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n", + "/Users/franck/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:34: NotOpenSSLWarning: urllib3 v2.0 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss: 0.2149 Acc: 1.0000\n" + ] + } + ], + "source": [ + "quantized_model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)\n", + "print_size_of_model(model, \"int8\")\n", + "\n", + "print_size_of_model(quantized_model, \"int8\")\n", + "\n", + "\n", + "\n", + "loss,acc=eval_model(quantized_model)\n", + "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "913c3dae", + "metadata": {}, + "outputs": [], + "source": [ + "loss,acc=eval_model(quantized_model)\n", + "print(\"Loss: {:.4f} Acc: {:.4f}\".format(loss, acc))" + ] + }, { "cell_type": "markdown", "id": "04a263f0", -- GitLab