diff --git a/TD3_Vision_Transformer_rendu.ipynb b/TD3_Vision_Transformer_rendu.ipynb index 3bc2ff3a016d1baf0add590dfd9905f47b59e68e..89c3c59f9839d4955feb2a36abae64f989432284 100644 --- a/TD3_Vision_Transformer_rendu.ipynb +++ b/TD3_Vision_Transformer_rendu.ipynb @@ -125,21 +125,7 @@ }, { "cell_type": "code", - "source": [ - "import matplotlib.pyplot as plt\n", - "import random\n", - "from torchvision.transforms import ToPILImage\n", - "\n", - "to_pil = ToPILImage()\n", - "random_index = random.randint(0, len(train_set) - 1)\n", - "image, label = train_set[random_index]\n", - "\n", - "image_pil = to_pil(image)\n", - "plt.imshow(image_pil, cmap='gray')\n", - "plt.title(f\"Classe : {label}\")\n", - "plt.axis('off')\n", - "plt.show()" - ], + "execution_count": 117, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -148,18 +134,32 @@ "id": "m6l69-vHAekF", "outputId": "65e9e2e0-e5bf-449c-d0e6-d13c4280519c" }, - "execution_count": 117, "outputs": [ { - "output_type": "display_data", "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAARMUlEQVR4nO3cf4zXBf3A8dd5J+fB3am7BLzDQSBrcNlZ/ljzlCsNBdHcNKnW4kdquNGUiqvWMo0/XEpGjLIw2+wH9MO12gy1HVltSQPKfqBTM9AgmLemAaJAxL2/fzRfXwmUz/vj/SDv8djY9PN5vz7v14fNe/r+3N27piiKIgAgIo4b6gUAOHaIAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAsesCRMmxLx584Z6DRhWRIFBt3nz5liwYEFMnDgxTjjhhGhubo7Ozs5Yvnx57N27d6jXGzK9vb2xYMGCaGtrixNOOCEmTJgQ11xzzVCvxTBTN9QLMLysWbMmrr766qivr485c+bEW9/61vjXv/4Vv/nNb6K7uzsee+yxuOuuu4Z6zUG3bdu26OzsjIiI66+/Ptra2mLHjh2xYcOGId6M4UYUGDRPP/10fOADH4jx48fHQw89FKeeemo+t3DhwvjrX/8aa9asGcINh86CBQuirq4uNm7cGC0tLUO9DsOYj48YNLfffnvs2bMnvvWtbx0ShJedfvrpceONN77q/PPPPx+LFy+OM844IxobG6O5uTlmzpwZf/rTnw47dsWKFdHe3h4jR46Mk08+Oc4+++xYvXp1Pv/CCy/EokWLYsKECVFfXx+jR4+O6dOnxyOPPHLI66xfvz5mzJgRJ554YowcOTK6urri4Ycfruj9bt26NZ544omjHvfEE0/EAw88EN3d3dHS0hL79u2LAwcOVHQO6G+iwKC57777YuLEiXHeeedVNb9ly5b46U9/Gpdddll8+ctfju7u7ti0aVN0dXXFjh078rhvfvObccMNN8TUqVPjK1/5SnzhC1+IM888M9avX5/HXH/99fH1r389rrrqqrjzzjtj8eLF0dDQEI8//nge89BDD8W0adNi9+7dcfPNN8ett94aO3fujAsvvLCij3XmzJkTU6ZMOepxa9eujYiIMWPGxEUXXRQNDQ3R0NAQM2fOjGeeeabE3xD0gwIGwa5du4qIKK644oqKZ8aPH1/MnTs3/33fvn3FwYMHDznm6aefLurr64slS5bkY1dccUXR3t7+mq994oknFgsXLnzV5/v6+orJkycXl1xySdHX15ePv/TSS8Wb3/zmYvr06Ufdv6urq6jkP7EbbrihiIiipaWlmDFjRvHDH/6wWLp0adHY2FhMmjSpePHFF4/6GtBffE+BQbF79+6IiGhqaqr6Nerr6/OfDx48GDt37ozGxsZ4y1vecsjHPieddFL8/e9/j40bN8Y555xzxNc66aSTYv369bFjx45obW097Pk//vGP8dRTT8XnPve5eO655w557qKLLorvfve70dfXF8cd9+oX27/61a8qel979uyJiIixY8fGmjVr8jXHjRsXH/zgB2P16tVx7bXXVvRa8Hr5+IhB0dzcHBH/+Sy/Wn19fbFs2bKYPHly1NfXx5ve9KY45ZRT4s9//nPs2rUrj/v0pz8djY2Nce6558bkyZNj4cKFh30f4Pbbb49HH300TjvttDj33HPjlltuiS1btuTzTz31VEREzJ07N0455ZRD/tx9992xf//+Q875ejQ0NERExOzZsw+JzNVXXx11dXWxbt26fjkPVEIUGBTNzc3R2toajz76aNWvceutt8YnPvGJmDZtWnzve9+Ln//859HT0xPt7e3R19eXx02ZMiWefPLJ+MEPfhDnn39+/PjHP47zzz8/br755jxm9uzZsWXLllixYkW0trbG0qVLo729PR544IGIiHy9pUuXRk9PzxH/NDY2Vv1eXunlK5UxY8Yc8nhtbW20tLTEP//5z345D1RkqD+/Yvj46Ec/WkREsW7duoqO/+/vKXR0dBTvfve7Dzuura2t6OrqetXX2b9/fzFr1qyitra22Lt37xGP6e3tLdra2orOzs6iKIpiw4YNRUQUK1eurGjX1+PBBx8sIqK46aabDtu7tra2uO666wZ8B3iZKwUGzac+9akYNWpUXHvttdHb23vY85s3b47ly5e/6nxtbW0URXHIY/fee29s3779kMf++3sAI0aMiKlTp0ZRFHHgwIE4ePDgYR/9jB49OlpbW2P//v0REXHWWWfFpEmT4ktf+lJ+5v9K//jHP177zUblP5L6rne9K0aPHh2rVq2Kffv25eP33HNPHDx4MKZPn37U14D+4hvNDJpJkybF6tWr4/3vf39MmTLlkN9oXrduXdx7772vea+jyy67LJYsWRLz58+P8847LzZt2hSrVq2KiRMnHnLcxRdfHGPHjo3Ozs4YM2ZMPP744/HVr341Zs2aFU1NTbFz584YN25cvO9974uOjo5obGyMtWvXxsaNG+OOO+6IiIjjjjsu7r777pg5c2a0t7fH/Pnzo62tLbZv3x6//OUvo7m5Oe67777XfL9z5syJX//614eF7L/V19fH0qVLY+7cuTFt2rT48Ic/HFu3bo3ly5fHBRdcEFdeeWVlf8HQH4b4SoVh6C9/+Utx3XXXFRMmTChGjBhRNDU1FZ2dncWKFSuKffv25XFH+pHUT37yk8Wpp55aNDQ0FJ2dncVvf/vboqur65CPj1auXFlMmzataGlpKerr64tJkyYV3d3dxa5du4qi+M/HMt3d3UVHR0fR1NRUjBo1qujo6CjuvPPOw3b9wx/+UFx55ZX5WuPHjy9mz55d/OIXvzjq+6z0R1Jf9v3vf7/o6Ogo6uvrizFjxhQf+9jHit27d1c8D/2hpiiO8r8xAAwbvqcAQBIFAJIoAJBEAYAkCgAkUQAgVfzLazU1NQO5BwADrJLfQHClAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECqG+oF4Fjxnve8p/RMT09P6ZlFixaVnomIWLFiRemZvr6+qs7F8OVKAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASDVFURQVHVhTM9C7QL+ZOnVq6Zn777+/9Mxpp51WeqZara2tpWd6e3sHYBP+V1Xy5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAUt1QLwBHc/zxx5eeue2220rPDObN7eBY5UoBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgDJDfE45p111lmlZy699NIB2ATe+FwpAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAguSEeg6a2traquZtuuqmfNxlat912W1Vzzz//fD9vAodzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgFRTFEVR0YE1NQO9C29wI0aMqGpu7969/bxJ/3nmmWdKz3R2dlZ1rmeffbaqOXhZJV/uXSkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgCpbqgX4H/T2WefXXpmyZIlA7BJ/9m2bVvpmYsvvrj0jLudcixzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgOSGeFTlwgsvLD1zySWXDMAm/eeee+4pPbN58+b+XwSGkCsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCkmqIoiooOrKkZ6F0YIu985ztLz/T09JSeGTlyZOmZam3atKn0zOWXX156Ztu2baVnYKhU8uXelQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAFLdUC/A0Fu8eHHpmcG8ud2BAwdKz3zmM58pPePmduBKAYBXEAUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACR3SX2DGT16dOmZt7/97QOwyeGKoqhq7uMf/3jpmQcffLCqc5U1duzY0jPz58+v6lyXXnppVXPHqq1bt1Y1t2zZstIzv/vd76o613DkSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAKmmqPAuZTU1NQO9C/3gRz/6UemZq666agA2OdwjjzxS1dw555zTz5sc2Tve8Y7SMz/5yU9Kz4wbN670DP/vhRdeKD1zxhlnlJ7Ztm1b6ZljXSVf7l0pAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAg1Q31AhzZmWeeWdXce9/73v5dpB/df//9g3auD33oQ6Vnli9fXnrm5JNPLj3D69PU1FR6ZsSIEQOwyRuTKwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACQ3xDtGjRo1qqq5448/vp83ObIDBw6Unlm2bFlV57rjjjtKz9x4442lZ2pqakrPDKZq/s57enpKz2zYsKH0zOzZs0vPTJ06tfQMA8+VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkhviDYJqblL32c9+dgA26T8rV64sPTNr1qyqzrVo0aKq5gZDb29v6Znf//73VZ3ri1/8YumZxx57rPTMvHnzSs8M5s3tXnrppdIz//73vwdgkzcmVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEByl9RBMH78+NIzM2bMGIBN+k9jY2Ppme985zsDsMmR7d+/v/TM2rVrS89cc801pWdaWlpKz0REXHDBBaVnvvGNb5SeGaw7nu7Zs6equWru0Pu3v/2tqnMNR64UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQaoqiKCo6sKZmoHd5wzr99NNLzzz55JMDsMnw8eKLL5aeefjhhwdgk8N1dnZWNTdq1Kh+3mRoffvb365q7iMf+Ug/bzJ8VPLl3pUCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSG+INgrFjx5ae2bhxY1Xnam1trWoOXo+f/exnpWfmzJlT1bl27dpV1RxuiAdASaIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJDqhnqB4eDZZ58tPXPXXXdVda5bbrmlqjnemNatW1d6Zt68eaVnent7S8/s2bOn9AwDz5UCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQBSTVEURUUH1tQM9C68Ql1ddfcqfNvb3lZ65vOf/3zpmcsvv7z0zBvRqlWrSs9s3769qnN97WtfKz3z3HPPlZ7Zu3dv6Rn+N1Ty5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkNwlFWCYcJdUAEoRBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAINVVemBRFAO5BwDHAFcKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKT/A41pVEW0tQbgAAAAAElFTkSuQmCC", "text/plain": [ "<Figure size 640x480 with 1 Axes>" - ], - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAARMUlEQVR4nO3cf4zXBf3A8dd5J+fB3am7BLzDQSBrcNlZ/ljzlCsNBdHcNKnW4kdquNGUiqvWMo0/XEpGjLIw2+wH9MO12gy1HVltSQPKfqBTM9AgmLemAaJAxL2/fzRfXwmUz/vj/SDv8djY9PN5vz7v14fNe/r+3N27piiKIgAgIo4b6gUAOHaIAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAsesCRMmxLx584Z6DRhWRIFBt3nz5liwYEFMnDgxTjjhhGhubo7Ozs5Yvnx57N27d6jXGzK9vb2xYMGCaGtrixNOOCEmTJgQ11xzzVCvxTBTN9QLMLysWbMmrr766qivr485c+bEW9/61vjXv/4Vv/nNb6K7uzsee+yxuOuuu4Z6zUG3bdu26OzsjIiI66+/Ptra2mLHjh2xYcOGId6M4UYUGDRPP/10fOADH4jx48fHQw89FKeeemo+t3DhwvjrX/8aa9asGcINh86CBQuirq4uNm7cGC0tLUO9DsOYj48YNLfffnvs2bMnvvWtbx0ShJedfvrpceONN77q/PPPPx+LFy+OM844IxobG6O5uTlmzpwZf/rTnw47dsWKFdHe3h4jR46Mk08+Oc4+++xYvXp1Pv/CCy/EokWLYsKECVFfXx+jR4+O6dOnxyOPPHLI66xfvz5mzJgRJ554YowcOTK6urri4Ycfruj9bt26NZ544omjHvfEE0/EAw88EN3d3dHS0hL79u2LAwcOVHQO6G+iwKC57777YuLEiXHeeedVNb9ly5b46U9/Gpdddll8+ctfju7u7ti0aVN0dXXFjh078rhvfvObccMNN8TUqVPjK1/5SnzhC1+IM888M9avX5/HXH/99fH1r389rrrqqrjzzjtj8eLF0dDQEI8//nge89BDD8W0adNi9+7dcfPNN8ett94aO3fujAsvvLCij3XmzJkTU6ZMOepxa9eujYiIMWPGxEUXXRQNDQ3R0NAQM2fOjGeeeabE3xD0gwIGwa5du4qIKK644oqKZ8aPH1/MnTs3/33fvn3FwYMHDznm6aefLurr64slS5bkY1dccUXR3t7+mq994oknFgsXLnzV5/v6+orJkycXl1xySdHX15ePv/TSS8Wb3/zmYvr06Ufdv6urq6jkP7EbbrihiIiipaWlmDFjRvHDH/6wWLp0adHY2FhMmjSpePHFF4/6GtBffE+BQbF79+6IiGhqaqr6Nerr6/OfDx48GDt37ozGxsZ4y1vecsjHPieddFL8/e9/j40bN8Y555xzxNc66aSTYv369bFjx45obW097Pk//vGP8dRTT8XnPve5eO655w557qKLLorvfve70dfXF8cd9+oX27/61a8qel979uyJiIixY8fGmjVr8jXHjRsXH/zgB2P16tVx7bXXVvRa8Hr5+IhB0dzcHBH/+Sy/Wn19fbFs2bKYPHly1NfXx5ve9KY45ZRT4s9//nPs2rUrj/v0pz8djY2Nce6558bkyZNj4cKFh30f4Pbbb49HH300TjvttDj33HPjlltuiS1btuTzTz31VEREzJ07N0455ZRD/tx9992xf//+Q875ejQ0NERExOzZsw+JzNVXXx11dXWxbt26fjkPVEIUGBTNzc3R2toajz76aNWvceutt8YnPvGJmDZtWnzve9+Ln//859HT0xPt7e3R19eXx02ZMiWefPLJ+MEPfhDnn39+/PjHP47zzz8/br755jxm9uzZsWXLllixYkW0trbG0qVLo729PR544IGIiHy9pUuXRk9PzxH/NDY2Vv1eXunlK5UxY8Yc8nhtbW20tLTEP//5z345D1RkqD+/Yvj46Ec/WkREsW7duoqO/+/vKXR0dBTvfve7Dzuura2t6OrqetXX2b9/fzFr1qyitra22Lt37xGP6e3tLdra2orOzs6iKIpiw4YNRUQUK1eurGjX1+PBBx8sIqK46aabDtu7tra2uO666wZ8B3iZKwUGzac+9akYNWpUXHvttdHb23vY85s3b47ly5e/6nxtbW0URXHIY/fee29s3779kMf++3sAI0aMiKlTp0ZRFHHgwIE4ePDgYR/9jB49OlpbW2P//v0REXHWWWfFpEmT4ktf+lJ+5v9K//jHP177zUblP5L6rne9K0aPHh2rVq2Kffv25eP33HNPHDx4MKZPn37U14D+4hvNDJpJkybF6tWr4/3vf39MmTLlkN9oXrduXdx7772vea+jyy67LJYsWRLz58+P8847LzZt2hSrVq2KiRMnHnLcxRdfHGPHjo3Ozs4YM2ZMPP744/HVr341Zs2aFU1NTbFz584YN25cvO9974uOjo5obGyMtWvXxsaNG+OOO+6IiIjjjjsu7r777pg5c2a0t7fH/Pnzo62tLbZv3x6//OUvo7m5Oe67777XfL9z5syJX//614eF7L/V19fH0qVLY+7cuTFt2rT48Ic/HFu3bo3ly5fHBRdcEFdeeWVlf8HQH4b4SoVh6C9/+Utx3XXXFRMmTChGjBhRNDU1FZ2dncWKFSuKffv25XFH+pHUT37yk8Wpp55aNDQ0FJ2dncVvf/vboqur65CPj1auXFlMmzataGlpKerr64tJkyYV3d3dxa5du4qi+M/HMt3d3UVHR0fR1NRUjBo1qujo6CjuvPPOw3b9wx/+UFx55ZX5WuPHjy9mz55d/OIXvzjq+6z0R1Jf9v3vf7/o6Ogo6uvrizFjxhQf+9jHit27d1c8D/2hpiiO8r8xAAwbvqcAQBIFAJIoAJBEAYAkCgAkUQAgVfzLazU1NQO5BwADrJLfQHClAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECqG+oF4Fjxnve8p/RMT09P6ZlFixaVnomIWLFiRemZvr6+qs7F8OVKAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASDVFURQVHVhTM9C7QL+ZOnVq6Zn777+/9Mxpp51WeqZara2tpWd6e3sHYBP+V1Xy5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAUt1QLwBHc/zxx5eeue2220rPDObN7eBY5UoBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgDJDfE45p111lmlZy699NIB2ATe+FwpAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAguSEeg6a2traquZtuuqmfNxlat912W1Vzzz//fD9vAodzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgFRTFEVR0YE1NQO9C29wI0aMqGpu7969/bxJ/3nmmWdKz3R2dlZ1rmeffbaqOXhZJV/uXSkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgCpbqgX4H/T2WefXXpmyZIlA7BJ/9m2bVvpmYsvvrj0jLudcixzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgOSGeFTlwgsvLD1zySWXDMAm/eeee+4pPbN58+b+XwSGkCsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCkmqIoiooOrKkZ6F0YIu985ztLz/T09JSeGTlyZOmZam3atKn0zOWXX156Ztu2baVnYKhU8uXelQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAFLdUC/A0Fu8eHHpmcG8ud2BAwdKz3zmM58pPePmduBKAYBXEAUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACR3SX2DGT16dOmZt7/97QOwyeGKoqhq7uMf/3jpmQcffLCqc5U1duzY0jPz58+v6lyXXnppVXPHqq1bt1Y1t2zZstIzv/vd76o613DkSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAKmmqPAuZTU1NQO9C/3gRz/6UemZq666agA2OdwjjzxS1dw555zTz5sc2Tve8Y7SMz/5yU9Kz4wbN670DP/vhRdeKD1zxhlnlJ7Ztm1b6ZljXSVf7l0pAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAg1Q31AhzZmWeeWdXce9/73v5dpB/df//9g3auD33oQ6Vnli9fXnrm5JNPLj3D69PU1FR6ZsSIEQOwyRuTKwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACQ3xDtGjRo1qqq5448/vp83ObIDBw6Unlm2bFlV57rjjjtKz9x4442lZ2pqakrPDKZq/s57enpKz2zYsKH0zOzZs0vPTJ06tfQMA8+VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkhviDYJqblL32c9+dgA26T8rV64sPTNr1qyqzrVo0aKq5gZDb29v6Znf//73VZ3ri1/8YumZxx57rPTMvHnzSs8M5s3tXnrppdIz//73vwdgkzcmVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEByl9RBMH78+NIzM2bMGIBN+k9jY2Ppme985zsDsMmR7d+/v/TM2rVrS89cc801pWdaWlpKz0REXHDBBaVnvvGNb5SeGaw7nu7Zs6equWru0Pu3v/2tqnMNR64UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQaoqiKCo6sKZmoHd5wzr99NNLzzz55JMDsMnw8eKLL5aeefjhhwdgk8N1dnZWNTdq1Kh+3mRoffvb365q7iMf+Ug/bzJ8VPLl3pUCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSG+INgrFjx5ae2bhxY1Xnam1trWoOXo+f/exnpWfmzJlT1bl27dpV1RxuiAdASaIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJDqhnqB4eDZZ58tPXPXXXdVda5bbrmlqjnemNatW1d6Zt68eaVnent7S8/s2bOn9AwDz5UCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQBSTVEURUUH1tQM9C68Ql1ddfcqfNvb3lZ65vOf/3zpmcsvv7z0zBvRqlWrSs9s3769qnN97WtfKz3z3HPPlZ7Zu3dv6Rn+N1Ty5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkNwlFWCYcJdUAEoRBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAINVVemBRFAO5BwDHAFcKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKT/A41pVEW0tQbgAAAAAElFTkSuQmCC\n" + ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import random\n", + "from torchvision.transforms import ToPILImage\n", + "\n", + "to_pil = ToPILImage()\n", + "random_index = random.randint(0, len(train_set) - 1)\n", + "image, label = train_set[random_index]\n", + "\n", + "image_pil = to_pil(image)\n", + "plt.imshow(image_pil, cmap='gray')\n", + "plt.title(f\"Classe : {label}\")\n", + "plt.axis('off')\n", + "plt.show()" ] }, { @@ -210,6 +210,27 @@ }, { "cell_type": "code", + "execution_count": 119, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 653 + }, + "id": "_7IX7zJpBJah", + "outputId": "a1365c4e-98a7-4ea4-a89b-a741ea219a40" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQxUlEQVR4nO3dMYic1QKG4ZnrroKmkCAiwdhopVEIsdBGEANKUFEkhWgjYm3sxSjaiJBGi3QKEQsRMY0GxE4sgtionQFDUNbCgBgIIeJ/ayGXjHpmZve+z1MP3x42h9mX02Q+TdM0AwAg4T/rPgAAAKsj/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQsrHoB+fz+TLPwQ616v8gxj3kSlZ5D91BrsR3IdvBovfQyx8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AAA7ycGDB4dtff7558O2jhw5Mmzr7bffHrY1m81mf/7559A94N/x8gcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELm0zRNC31wPl/2WdiBFrw+w7iH/z/uvPPOYVvff//9sK2rOXv27LCtvXv3Dtsaac+ePUP3fvnll6F725HvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQBg59jc3By29dZbbw3bWqW9e/eu+wgA/4qXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAELKx7gMAO8eBAweGbR06dGjYFgCL8/IHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCNtZ9AGC5rrnmmmFbR48eHbbFWG+++eawrfPnzw/bArYfL38AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtO5DAACwGl7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjGoh+cz+fLPAc71DRNK/157uHfd+211w7bunTp0rAtZrMff/xx2Nb9998/bGtra2vYVoXvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCyse4DAH917733Dt174403hu7VnTt3btjWwYMHh21tbW0N2wL+v3n5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIRvrPgDwVw899NDQvYcffnjoXt277747bOvMmTPDtgAW5eUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYbZrvfwvvvuG7b1xRdfDNuazWaz66+/fujeKN9+++2wrbvvvnvY1tXcdtttw7bOnTs3bIv18l3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDIfJqmaaEPzufLPgs70ILXZ5jteg8//vjjYVtPPvnksK3RLl++PGzr8ccfH7b12WefDdu6mu16B1kv34VsB4veQy9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZGPdB4B1ufnmm4dt7d+/f9jWaNM0Ddt68cUXh22dOnVq2NZOdcsttwzbev7554dtHTp0aNgWy/HBBx8M2zp27Niwra+//nrYFsvj5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAITMp2maFvrgfL7ss7ADLXh9hhl5Dz/66KNhW0899dSwrdG++eabYVsHDhwYtjXSKu/hyN/ByZMnh23deuutw7Zo+f3334dt3XXXXcO2zp07N2yrYtHvQi9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZD5N07TuQwAAsBpe/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjYW/eB8Pl/mOdihpmla6c/bv3//sK3Tp08P29rc3By2Ndrrr78+bOuVV14ZtvXss88O2zpx4sSwras5f/78sK3du3cP24Lt4I477hi2debMmWFbFYv+TfbyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQD4O2644YZhW5ubm8O2Rrp8+fLQvWPHjm3LrSNHjgzbWqXdu3ev+whLN/oOnjp1atjWV199NWzrmWeeGba1b9++YVuwbF7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAyMa6D8Dq3X777es+wj/28ssvr/sIS3f8+PGhe48++uiwrZdeemnYFrPZ1tbWsK3Tp08P2zp69OiwrdlsNvvuu++Gbb3wwgvDtvbt2zdsaye7ePHisK0//vhj2BbL4+UPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AKzeDz/8sO4j/GOPPPLIuo+wdLt27Rq6d+LEiaF7o1y6dGnY1nXXXTds62pOnjw5bOvpp58etnXTTTcN23rwwQeHbc1mY+/gvn37hm3tZBcuXBi2dfz48WFbZ8+eHbbF8nj5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtNAH5/Nln4UVWfCfHJbqwoULw7Z27do1bOtqPv3005X9rL/jgQceGLa1yt8n/8x77703bOu5554btsV6Lfr33csfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNC30wfl82WdhRT788MNhW4cPHx62tYiff/552NaePXuGbQE7zyeffDJs64knnhi2tYgbb7xx2NZvv/02bIv1WjDpvPwBAJSIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AICQ+TRN00IfnM+XfRZ2oAWvzzBHjx4dtvXaa68N2wL+ty+//HLY1uHDh4dtbW1tDdta9Xehv8lcyaL30MsfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNK37EAAArIaXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCysegH5/P5Ms/BDjVN00p/3ubm5rCte+65Z9jWq6++OmzrscceG7bF9vb+++8P2/rpp5+Gbb3zzjvDtmaz2ezXX38dtnXx4sVhWyOt+rvQ32SuZNF76OUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYZxD7mSVd5Dd5Ar8V3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQubTNE3rPgQAAKvh5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACDkv4dFfAX0SOj9AAAAAElFTkSuQmCC", + "text/plain": [ + "<Figure size 800x800 with 16 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "def display_patches(patches, n_patches, image_size):\n", " fig, axes = plt.subplots(n_patches, n_patches, figsize=(8, 8))\n", @@ -227,27 +248,6 @@ "\n", "# Affichez les patches\n", "display_patches(patches[0], n_patches, (1, 7, 7))\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 653 - }, - "id": "_7IX7zJpBJah", - "outputId": "a1365c4e-98a7-4ea4-a89b-a741ea219a40" - }, - "execution_count": 119, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "<Figure size 800x800 with 16 Axes>" - ], - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQxUlEQVR4nO3dMYic1QKG4ZnrroKmkCAiwdhopVEIsdBGEANKUFEkhWgjYm3sxSjaiJBGi3QKEQsRMY0GxE4sgtionQFDUNbCgBgIIeJ/ayGXjHpmZve+z1MP3x42h9mX02Q+TdM0AwAg4T/rPgAAAKsj/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQsrHoB+fz+TLPwQ616v8gxj3kSlZ5D91BrsR3IdvBovfQyx8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AAA7ycGDB4dtff7558O2jhw5Mmzr7bffHrY1m81mf/7559A94N/x8gcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELm0zRNC31wPl/2WdiBFrw+w7iH/z/uvPPOYVvff//9sK2rOXv27LCtvXv3Dtsaac+ePUP3fvnll6F725HvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQBg59jc3By29dZbbw3bWqW9e/eu+wgA/4qXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAELKx7gMAO8eBAweGbR06dGjYFgCL8/IHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCNtZ9AGC5rrnmmmFbR48eHbbFWG+++eawrfPnzw/bArYfL38AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtO5DAACwGl7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjGoh+cz+fLPAc71DRNK/157uHfd+211w7bunTp0rAtZrMff/xx2Nb9998/bGtra2vYVoXvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCyse4DAH917733Dt174403hu7VnTt3btjWwYMHh21tbW0N2wL+v3n5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIRvrPgDwVw899NDQvYcffnjoXt277747bOvMmTPDtgAW5eUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYbZrvfwvvvuG7b1xRdfDNuazWaz66+/fujeKN9+++2wrbvvvnvY1tXcdtttw7bOnTs3bIv18l3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDIfJqmaaEPzufLPgs70ILXZ5jteg8//vjjYVtPPvnksK3RLl++PGzr8ccfH7b12WefDdu6mu16B1kv34VsB4veQy9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZGPdB4B1ufnmm4dt7d+/f9jWaNM0Ddt68cUXh22dOnVq2NZOdcsttwzbev7554dtHTp0aNgWy/HBBx8M2zp27Niwra+//nrYFsvj5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAITMp2maFvrgfL7ss7ADLXh9hhl5Dz/66KNhW0899dSwrdG++eabYVsHDhwYtjXSKu/hyN/ByZMnh23deuutw7Zo+f3334dt3XXXXcO2zp07N2yrYtHvQi9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZD5N07TuQwAAsBpe/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjYW/eB8Pl/mOdihpmla6c/bv3//sK3Tp08P29rc3By2Ndrrr78+bOuVV14ZtvXss88O2zpx4sSwras5f/78sK3du3cP24Lt4I477hi2debMmWFbFYv+TfbyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQD4O2644YZhW5ubm8O2Rrp8+fLQvWPHjm3LrSNHjgzbWqXdu3ev+whLN/oOnjp1atjWV199NWzrmWeeGba1b9++YVuwbF7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAyMa6D8Dq3X777es+wj/28ssvr/sIS3f8+PGhe48++uiwrZdeemnYFrPZ1tbWsK3Tp08P2zp69OiwrdlsNvvuu++Gbb3wwgvDtvbt2zdsaye7ePHisK0//vhj2BbL4+UPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AKzeDz/8sO4j/GOPPPLIuo+wdLt27Rq6d+LEiaF7o1y6dGnY1nXXXTds62pOnjw5bOvpp58etnXTTTcN23rwwQeHbc1mY+/gvn37hm3tZBcuXBi2dfz48WFbZ8+eHbbF8nj5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtNAH5/Nln4UVWfCfHJbqwoULw7Z27do1bOtqPv3005X9rL/jgQceGLa1yt8n/8x77703bOu5554btsV6Lfr33csfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNC30wfl82WdhRT788MNhW4cPHx62tYiff/552NaePXuGbQE7zyeffDJs64knnhi2tYgbb7xx2NZvv/02bIv1WjDpvPwBAJSIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AICQ+TRN00IfnM+XfRZ2oAWvzzBHjx4dtvXaa68N2wL+ty+//HLY1uHDh4dtbW1tDdta9Xehv8lcyaL30MsfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNK37EAAArIaXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCysegH5/P5Ms/BDjVN00p/3ubm5rCte+65Z9jWq6++OmzrscceG7bF9vb+++8P2/rpp5+Gbb3zzjvDtmaz2ezXX38dtnXx4sVhWyOt+rvQ32SuZNF76OUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYZxD7mSVd5Dd5Ar8V3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQubTNE3rPgQAAKvh5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACDkv4dFfAX0SOj9AAAAAElFTkSuQmCC\n" - }, - "metadata": {} - } ] }, { @@ -265,29 +265,20 @@ }, { "cell_type": "code", + "execution_count": 120, + "metadata": { + "id": "RhglaVPb59Ll" + }, + "outputs": [], "source": [ "def features_embedding(patch,applatisseur,class_embaded):\n", " embedded_patches = applatisseur(flattened_patches)\n", " return(torch.cat((class_embaded, embedded_patches), dim=0))" - ], - "metadata": { - "id": "RhglaVPb59Ll" - }, - "execution_count": 120, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "hidden_d=8\n", - "batch_size=patches.shape[1]\n", - "flattened_patches = patches.view(batch_size, -1)\n", - "linear_layer = nn.Linear(flattened_patches.size(1), hidden_d)\n", - "classe_embedded = torch.rand((1,hidden_d))\n", - "\n", - "features_emb=features_embedding(patches,linear_layer,classe_embedded)\n", - "print(features_emb)" - ], + "execution_count": 121, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -296,11 +287,10 @@ "id": "ayIaewry62c-", "outputId": "5c13d8d8-87f9-4353-ca93-08cf8aad095d" }, - "execution_count": 121, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "tensor([[ 0.0886, 0.1783, 0.7162, 0.6681, 0.3244, 0.4643, 0.1293, 0.5872],\n", " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932],\n", @@ -322,6 +312,16 @@ " grad_fn=<CatBackward0>)\n" ] } + ], + "source": [ + "hidden_d=8\n", + "batch_size=patches.shape[1]\n", + "flattened_patches = patches.view(batch_size, -1)\n", + "linear_layer = nn.Linear(flattened_patches.size(1), hidden_d)\n", + "classe_embedded = torch.rand((1,hidden_d))\n", + "\n", + "features_emb=features_embedding(patches,linear_layer,classe_embedded)\n", + "print(features_emb)" ] }, { @@ -374,10 +374,7 @@ }, { "cell_type": "code", - "source": [ - "positional_emb=get_positional_embeddings(17,8)\n", - "print(positional_emb)" - ], + "execution_count": 123, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -386,11 +383,10 @@ "id": "PHvqYlLZFp6o", "outputId": "583e9e21-81c9-481f-c235-de5180e88f98" }, - "execution_count": 123, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,\n", " 1.0000e+00, 0.0000e+00, 1.0000e+00],\n", @@ -428,15 +424,15 @@ " 9.8723e-01, 1.5999e-02, 9.9987e-01]])\n" ] } + ], + "source": [ + "positional_emb=get_positional_embeddings(17,8)\n", + "print(positional_emb)" ] }, { "cell_type": "code", - "source": [ - "input_vect=features_emb+positional_emb\n", - "print(input_vect.shape)\n", - "print(input_vect)" - ], + "execution_count": 124, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -445,11 +441,10 @@ "id": "L6sGNOSl-r8x", "outputId": "726a15e7-958f-4273-f4eb-33867a8b314c" }, - "execution_count": 124, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "torch.Size([17, 8])\n", "tensor([[ 0.0886, 1.1783, 0.7162, 1.6681, 0.3244, 1.4643, 0.1293, 1.5872],\n", @@ -472,6 +467,11 @@ " grad_fn=<AddBackward0>)\n" ] } + ], + "source": [ + "input_vect=features_emb+positional_emb\n", + "print(input_vect.shape)\n", + "print(input_vect)" ] }, { @@ -768,8 +768,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Using device: cuda (Tesla T4)\n" ] @@ -813,70 +813,70 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 1/5 loss: 1.83\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 2/5 loss: 1.78\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 3/5 loss: 1.74\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 4/5 loss: 1.70\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ " " ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 5/5 loss: 1.68\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\r" ] @@ -920,32 +920,32 @@ "cell_type": "code", "execution_count": 137, "metadata": { - "id": "h55dVGGhOaPI", "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, + "id": "h55dVGGhOaPI", "outputId": "db5397d7-2402-4f30-c56a-d8e714a89b94" }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ " " ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Test loss: 1.71\n", "Test accuracy: 77.85%\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "\r" ] @@ -973,80 +973,6 @@ " print(f\"Test accuracy: {correct / total * 100:.2f}%\")\n" ] }, - { - "cell_type": "code", - "source": [ - "# track test loss\n", - "test_loss = 0.0\n", - "class_correct_NET = list(0.0 for i in range(10))\n", - "class_total_NET = list(0.0 for i in range(10))\n", - "\n", - "import torch.optim as optim\n", - "\n", - "criterion = nn.CrossEntropyLoss() # specify loss function\n", - "optimizer = optim.SGD(model.parameters(), lr=0.01) # specify optimizer\n", - "\n", - "\n", - "model.eval()\n", - "# iterate over test data\n", - "for data, target in test_loader:\n", - " # move tensors to GPU if CUDA is available\n", - " if train_on_gpu:\n", - " data, target = data.cuda(), target.cuda()\n", - " # forward pass: compute predicted outputs by passing inputs to the model\n", - " output = model(data)\n", - " # calculate the batch loss\n", - " loss = criterion(output, target)\n", - " # update test loss\n", - " test_loss += loss.item() * data.size(0)\n", - " # convert output probabilities to predicted class\n", - " _, pred = torch.max(output, 1)\n", - " # compare predictions to true label\n", - " correct_tensor = pred.eq(target.data.view_as(pred))\n", - " correct = (\n", - " np.squeeze(correct_tensor.numpy())\n", - " if not train_on_gpu\n", - " else np.squeeze(correct_tensor.cpu().numpy())\n", - " )\n", - " # calculate test accuracy for each object class\n", - " for i in range(batch_size):\n", - " label = target.data[i]\n", - " class_correct_NET[label] += correct[i].item()\n", - " class_total_NET[label] += 1\n", - "\n", - "# average test loss\n", - "test_loss = test_loss / len(test_loader)\n", - "print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n", - "\n", - "for i in range(10):\n", - " if class_total_NET[i] > 0:\n", - " print(\n", - " \"Test Accuracy of %5s: %2d%% (%2d/%2d)\"\n", - " % (\n", - " classes[i],\n", - " 100 * class_correct_NET[i] / class_total_NET[i],\n", - " np.sum(class_correct_NET[i]),\n", - " np.sum(class_total_NET[i]),\n", - " )\n", - " )\n", - " else:\n", - " print(\"Test Accuracy of %5s: N/A (no training examples)\" % (classes[i]))\n", - "\n", - "print(\n", - " \"\\nTest Accuracy (Overall): %2d%% (%2d/%2d)\"\n", - " % (\n", - " 100.0 * np.sum(class_correct_NET) / np.sum(class_total_NET),\n", - " np.sum(class_correct_NET),\n", - " np.sum(class_total_NET),\n", - " )\n", - ")" - ], - "metadata": { - "id": "mVrKCa4vjWfy" - }, - "execution_count": null, - "outputs": [] - }, { "cell_type": "markdown", "metadata": { @@ -1087,4 +1013,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +}