diff --git a/TD3_Vision_Transformer_rendu.ipynb b/TD3_Vision_Transformer_rendu.ipynb index 89c3c59f9839d4955feb2a36abae64f989432284..ca5f09da1a82024bca1ce29490c77f02d83d0981 100644 --- a/TD3_Vision_Transformer_rendu.ipynb +++ b/TD3_Vision_Transformer_rendu.ipynb @@ -75,7 +75,101 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Requirement already satisfied: numpy in /home/cosserat/.local/lib/python3.10/site-packages (1.26.3)\n", + "Defaulting to user installation because normal site-packages is not writeable\n", + "Requirement already satisfied: torch in /home/cosserat/.local/lib/python3.10/site-packages (2.1.2)\n", + "Requirement already satisfied: filelock in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.3.1)\n", + "Requirement already satisfied: sympy in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (1.12)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.0.106)\n", + "Requirement already satisfied: networkx in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (3.2.1)\n", + "Requirement already satisfied: fsspec in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (2023.12.2)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: jinja2 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (3.1.3)\n", + "Requirement already satisfied: triton==2.1.0 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (2.1.0)\n", + "Requirement already satisfied: typing-extensions in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (4.9.0)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (2.18.1)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/cosserat/.local/lib/python3.10/site-packages (from torch) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/cosserat/.local/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.3.101)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/cosserat/.local/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/cosserat/.local/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", + "Defaulting to user installation because normal site-packages is not writeable\n", + "Requirement already satisfied: torchvision in /home/cosserat/.local/lib/python3.10/site-packages (0.16.2)\n", + "Requirement already satisfied: numpy in /home/cosserat/.local/lib/python3.10/site-packages (from torchvision) (1.26.3)\n", + "Requirement already satisfied: torch==2.1.2 in /home/cosserat/.local/lib/python3.10/site-packages (from torchvision) (2.1.2)\n", + "Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from torchvision) (2.25.1)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/lib/python3/dist-packages (from torchvision) (9.0.1)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.0.106)\n", + "Requirement already satisfied: filelock in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (3.13.1)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (11.4.5.107)\n", + "Requirement already satisfied: sympy in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (1.12)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.105)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.105)\n", + "Requirement already satisfied: typing-extensions in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (4.9.0)\n", + "Requirement already satisfied: networkx in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (3.2.1)\n", + "Requirement already satisfied: triton==2.1.0 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (2.1.0)\n", + "Requirement already satisfied: jinja2 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (3.1.3)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (12.1.105)\n", + "Requirement already satisfied: fsspec in /home/cosserat/.local/lib/python3.10/site-packages (from torch==2.1.2->torchvision) (2023.12.2)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/cosserat/.local/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.1.2->torchvision) (12.3.101)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/cosserat/.local/lib/python3.10/site-packages (from jinja2->torch==2.1.2->torchvision) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/cosserat/.local/lib/python3.10/site-packages (from sympy->torch==2.1.2->torchvision) (1.3.0)\n", + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting matplotlib\n", + " Downloading matplotlib-3.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy<2,>=1.21 in /home/cosserat/.local/lib/python3.10/site-packages (from matplotlib) (1.26.3)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib) (2.4.7)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/cosserat/.local/lib/python3.10/site-packages (from matplotlib) (2.8.2)\n", + "Collecting kiwisolver>=1.3.1\n", + " Downloading kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m:01\u001b[0m\n", + "\u001b[?25hCollecting fonttools>=4.22.0\n", + " Downloading fonttools-4.47.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0m eta \u001b[36m0:00:01\u001b[0m0:01\u001b[0m:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pillow>=8 in /usr/lib/python3/dist-packages (from matplotlib) (9.0.1)\n", + "Collecting cycler>=0.10\n", + " Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n", + "Collecting contourpy>=1.0.1\n", + " Downloading contourpy-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (310 kB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.7/310.7 KB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /home/cosserat/.local/lib/python3.10/site-packages (from matplotlib) (23.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", + "Installing collected packages: kiwisolver, fonttools, cycler, contourpy, matplotlib\n", + "Successfully installed contourpy-1.2.0 cycler-0.12.1 fonttools-4.47.2 kiwisolver-1.4.5 matplotlib-3.8.2\n" + ] + } + ], + "source": [ + "!pip install numpy\n", + "!pip install torch\n", + "!pip install torchvision\n", + "!pip install matplotlib" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "id": "wEmbaOA4Okuo" }, @@ -89,7 +183,11 @@ "from torch.optim import Adam\n", "from torch.utils.data import DataLoader\n", "from torchvision.datasets.mnist import MNIST\n", - "from torchvision.transforms import ToTensor" + "from torchvision.transforms import ToTensor\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import random\n", + "from torchvision.transforms import ToPILImage" ] }, { @@ -103,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 9, "metadata": { "id": "crfmWV8uc4wm" }, @@ -123,9 +221,16 @@ "test_loader = DataLoader(test_set, shuffle=False, batch_size=128)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Le code précédent à charger le Dataset de MNIST. Nous pouvons afficher un de ces éléments aléatoire :" + ] + }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -137,7 +242,7 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAARMUlEQVR4nO3cf4zXBf3A8dd5J+fB3am7BLzDQSBrcNlZ/ljzlCsNBdHcNKnW4kdquNGUiqvWMo0/XEpGjLIw2+wH9MO12gy1HVltSQPKfqBTM9AgmLemAaJAxL2/fzRfXwmUz/vj/SDv8djY9PN5vz7v14fNe/r+3N27piiKIgAgIo4b6gUAOHaIAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAsesCRMmxLx584Z6DRhWRIFBt3nz5liwYEFMnDgxTjjhhGhubo7Ozs5Yvnx57N27d6jXGzK9vb2xYMGCaGtrixNOOCEmTJgQ11xzzVCvxTBTN9QLMLysWbMmrr766qivr485c+bEW9/61vjXv/4Vv/nNb6K7uzsee+yxuOuuu4Z6zUG3bdu26OzsjIiI66+/Ptra2mLHjh2xYcOGId6M4UYUGDRPP/10fOADH4jx48fHQw89FKeeemo+t3DhwvjrX/8aa9asGcINh86CBQuirq4uNm7cGC0tLUO9DsOYj48YNLfffnvs2bMnvvWtbx0ShJedfvrpceONN77q/PPPPx+LFy+OM844IxobG6O5uTlmzpwZf/rTnw47dsWKFdHe3h4jR46Mk08+Oc4+++xYvXp1Pv/CCy/EokWLYsKECVFfXx+jR4+O6dOnxyOPPHLI66xfvz5mzJgRJ554YowcOTK6urri4Ycfruj9bt26NZ544omjHvfEE0/EAw88EN3d3dHS0hL79u2LAwcOVHQO6G+iwKC57777YuLEiXHeeedVNb9ly5b46U9/Gpdddll8+ctfju7u7ti0aVN0dXXFjh078rhvfvObccMNN8TUqVPjK1/5SnzhC1+IM888M9avX5/HXH/99fH1r389rrrqqrjzzjtj8eLF0dDQEI8//nge89BDD8W0adNi9+7dcfPNN8ett94aO3fujAsvvLCij3XmzJkTU6ZMOepxa9eujYiIMWPGxEUXXRQNDQ3R0NAQM2fOjGeeeabE3xD0gwIGwa5du4qIKK644oqKZ8aPH1/MnTs3/33fvn3FwYMHDznm6aefLurr64slS5bkY1dccUXR3t7+mq994oknFgsXLnzV5/v6+orJkycXl1xySdHX15ePv/TSS8Wb3/zmYvr06Ufdv6urq6jkP7EbbrihiIiipaWlmDFjRvHDH/6wWLp0adHY2FhMmjSpePHFF4/6GtBffE+BQbF79+6IiGhqaqr6Nerr6/OfDx48GDt37ozGxsZ4y1vecsjHPieddFL8/e9/j40bN8Y555xzxNc66aSTYv369bFjx45obW097Pk//vGP8dRTT8XnPve5eO655w557qKLLorvfve70dfXF8cd9+oX27/61a8qel979uyJiIixY8fGmjVr8jXHjRsXH/zgB2P16tVx7bXXVvRa8Hr5+IhB0dzcHBH/+Sy/Wn19fbFs2bKYPHly1NfXx5ve9KY45ZRT4s9//nPs2rUrj/v0pz8djY2Nce6558bkyZNj4cKFh30f4Pbbb49HH300TjvttDj33HPjlltuiS1btuTzTz31VEREzJ07N0455ZRD/tx9992xf//+Q875ejQ0NERExOzZsw+JzNVXXx11dXWxbt26fjkPVEIUGBTNzc3R2toajz76aNWvceutt8YnPvGJmDZtWnzve9+Ln//859HT0xPt7e3R19eXx02ZMiWefPLJ+MEPfhDnn39+/PjHP47zzz8/br755jxm9uzZsWXLllixYkW0trbG0qVLo729PR544IGIiHy9pUuXRk9PzxH/NDY2Vv1eXunlK5UxY8Yc8nhtbW20tLTEP//5z345D1RkqD+/Yvj46Ec/WkREsW7duoqO/+/vKXR0dBTvfve7Dzuura2t6OrqetXX2b9/fzFr1qyitra22Lt37xGP6e3tLdra2orOzs6iKIpiw4YNRUQUK1eurGjX1+PBBx8sIqK46aabDtu7tra2uO666wZ8B3iZKwUGzac+9akYNWpUXHvttdHb23vY85s3b47ly5e/6nxtbW0URXHIY/fee29s3779kMf++3sAI0aMiKlTp0ZRFHHgwIE4ePDgYR/9jB49OlpbW2P//v0REXHWWWfFpEmT4ktf+lJ+5v9K//jHP177zUblP5L6rne9K0aPHh2rVq2Kffv25eP33HNPHDx4MKZPn37U14D+4hvNDJpJkybF6tWr4/3vf39MmTLlkN9oXrduXdx7772vea+jyy67LJYsWRLz58+P8847LzZt2hSrVq2KiRMnHnLcxRdfHGPHjo3Ozs4YM2ZMPP744/HVr341Zs2aFU1NTbFz584YN25cvO9974uOjo5obGyMtWvXxsaNG+OOO+6IiIjjjjsu7r777pg5c2a0t7fH/Pnzo62tLbZv3x6//OUvo7m5Oe67777XfL9z5syJX//614eF7L/V19fH0qVLY+7cuTFt2rT48Ic/HFu3bo3ly5fHBRdcEFdeeWVlf8HQH4b4SoVh6C9/+Utx3XXXFRMmTChGjBhRNDU1FZ2dncWKFSuKffv25XFH+pHUT37yk8Wpp55aNDQ0FJ2dncVvf/vboqur65CPj1auXFlMmzataGlpKerr64tJkyYV3d3dxa5du4qi+M/HMt3d3UVHR0fR1NRUjBo1qujo6CjuvPPOw3b9wx/+UFx55ZX5WuPHjy9mz55d/OIXvzjq+6z0R1Jf9v3vf7/o6Ogo6uvrizFjxhQf+9jHit27d1c8D/2hpiiO8r8xAAwbvqcAQBIFAJIoAJBEAYAkCgAkUQAgVfzLazU1NQO5BwADrJLfQHClAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAECqG+oF4Fjxnve8p/RMT09P6ZlFixaVnomIWLFiRemZvr6+qs7F8OVKAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASDVFURQVHVhTM9C7QL+ZOnVq6Zn777+/9Mxpp51WeqZara2tpWd6e3sHYBP+V1Xy5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAUt1QLwBHc/zxx5eeue2220rPDObN7eBY5UoBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgDJDfE45p111lmlZy699NIB2ATe+FwpAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAguSEeg6a2traquZtuuqmfNxlat912W1Vzzz//fD9vAodzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgFRTFEVR0YE1NQO9C29wI0aMqGpu7969/bxJ/3nmmWdKz3R2dlZ1rmeffbaqOXhZJV/uXSkAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgCpbqgX4H/T2WefXXpmyZIlA7BJ/9m2bVvpmYsvvrj0jLudcixzpQBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgOSGeFTlwgsvLD1zySWXDMAm/eeee+4pPbN58+b+XwSGkCsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCkmqIoiooOrKkZ6F0YIu985ztLz/T09JSeGTlyZOmZam3atKn0zOWXX156Ztu2baVnYKhU8uXelQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAFLdUC/A0Fu8eHHpmcG8ud2BAwdKz3zmM58pPePmduBKAYBXEAUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACR3SX2DGT16dOmZt7/97QOwyeGKoqhq7uMf/3jpmQcffLCqc5U1duzY0jPz58+v6lyXXnppVXPHqq1bt1Y1t2zZstIzv/vd76o613DkSgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAKmmqPAuZTU1NQO9C/3gRz/6UemZq666agA2OdwjjzxS1dw555zTz5sc2Tve8Y7SMz/5yU9Kz4wbN670DP/vhRdeKD1zxhlnlJ7Ztm1b6ZljXSVf7l0pAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAg1Q31AhzZmWeeWdXce9/73v5dpB/df//9g3auD33oQ6Vnli9fXnrm5JNPLj3D69PU1FR6ZsSIEQOwyRuTKwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACQ3xDtGjRo1qqq5448/vp83ObIDBw6Unlm2bFlV57rjjjtKz9x4442lZ2pqakrPDKZq/s57enpKz2zYsKH0zOzZs0vPTJ06tfQMA8+VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkhviDYJqblL32c9+dgA26T8rV64sPTNr1qyqzrVo0aKq5gZDb29v6Znf//73VZ3ri1/8YumZxx57rPTMvHnzSs8M5s3tXnrppdIz//73vwdgkzcmVwoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEByl9RBMH78+NIzM2bMGIBN+k9jY2Ppme985zsDsMmR7d+/v/TM2rVrS89cc801pWdaWlpKz0REXHDBBaVnvvGNb5SeGaw7nu7Zs6equWru0Pu3v/2tqnMNR64UAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQaoqiKCo6sKZmoHd5wzr99NNLzzz55JMDsMnw8eKLL5aeefjhhwdgk8N1dnZWNTdq1Kh+3mRoffvb365q7iMf+Ug/bzJ8VPLl3pUCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSG+INgrFjx5ae2bhxY1Xnam1trWoOXo+f/exnpWfmzJlT1bl27dpV1RxuiAdASaIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJDqhnqB4eDZZ58tPXPXXXdVda5bbrmlqjnemNatW1d6Zt68eaVnent7S8/s2bOn9AwDz5UCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQBSTVEURUUH1tQM9C68Ql1ddfcqfNvb3lZ65vOf/3zpmcsvv7z0zBvRqlWrSs9s3769qnN97WtfKz3z3HPPlZ7Zu3dv6Rn+N1Ty5d6VAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkNwlFWCYcJdUAEoRBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAINVVemBRFAO5BwDHAFcKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKAKT/A41pVEW0tQbgAAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQdElEQVR4nO3df6zVdf3A8dcBxk259wqjKUIOBJzgrV2brRqXcZcJccNJs5Sy7ZJLw81lLb2ZW43iD/6Afsgwmj9amxXLqOmmrvwRlhkNdCWRSahAEa2b4EBENLv3/f2jr69FCJ7PkXsvPx6P7W5y7+d1zvtscJ/3fc65b2ullBIAEBHDhnoBABw7RAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRIFj1qRJk+JTn/rUUC8DTiqiwKB77rnnYtGiRTF58uR429veFq2trdHR0RErVqyIAwcODPXyhtxjjz0WtVotarVa7Nq1a6iXw0lmxFAvgJPL/fffH5dddlk0NTVFd3d3vPOd74x//etf8dhjj0VPT0889dRTcdtttw31ModMf39/fPazn41Ro0bF/v37h3o5nIREgUGzbdu2+PjHPx4TJ06MtWvXxplnnplfu/baa+PZZ5+N+++/fwhXOPRuu+222LFjR1x11VWxYsWKoV4OJyFPHzFoli1bFi+99FJ897vfPSgIr5s6dWp87nOfO+z8Cy+8EDfccEO8613viubm5mhtbY2urq7YuHHjIdeuXLky2tra4tRTT40xY8bEe97znli9enV+fd++ffH5z38+Jk2aFE1NTXH66afH7Nmz43e/+91Bt7N+/fqYO3dunHbaaXHqqadGZ2dn/OY3v6nr8f71r3+NzZs313Xt64/vy1/+cixZsiRGjx5d9xwcTaLAoLn33ntj8uTJMWPGjIbmt27dGvfcc09cfPHF8c1vfjN6enpi06ZN0dnZGX//+9/zuttvvz2uu+66OO+88+Lmm2+Or33ta3H++efH+vXr85prrrkmvvOd78RHP/rRWLVqVdxwww1xyimnxNNPP53XrF27NmbNmhUvvvhiLF68OJYuXRp79uyJCy+8MDZs2PCm6+3u7o7p06fX/fi+8pWvxLhx42LRokV1z8BRV2AQ7N27t0REmT9/ft0zEydOLAsXLsw/v/LKK6Wvr++ga7Zt21aamprKkiVL8nPz588vbW1tR7zt0047rVx77bWH/Xp/f38555xzyoc+9KHS39+fn3/55ZfL2WefXWbPnv2m6+/s7Cz1/hPbuHFjGT58eHnggQdKKaUsXry4RER5/vnn65qHo8VOgUHx4osvRkRES0tLw7fR1NQUw4b9569sX19f7N69O5qbm+Pcc8896Gmf0aNHx9/+9rd4/PHHD3tbo0ePjvXr1x+0w/hvTz75ZDzzzDNxxRVXxO7du2PXrl2xa9eu2L9/f3zwgx+MRx99NPr7+4+43l/+8pdR6vx/WF133XXR1dUVc+bMqet6GCheaGZQtLa2RsR/nstvVH9/f6xYsSJWrVoV27Zti76+vvza2LFj879vvPHGePjhh+O9731vTJ06NebMmRNXXHFFdHR05DXLli2LhQsXxllnnRUXXHBBfPjDH47u7u6YPHlyREQ888wzERGxcOHCw65n7969MWbMmIYfz+vuuuuuWLduXfzxj398y7cFb5WdAoOitbU1xo8f/5a+8S1dujS+8IUvxKxZs+IHP/hBPPDAA/HQQw9FW1vbQT+1T58+Pf785z/Hj370o5g5c2b89Kc/jZkzZ8bixYvzmssvvzy2bt0aK1eujPHjx8fy5cujra0tfvazn0VE5O0tX748HnrooTf8aG5ubvix/Leenp647LLLYuTIkbF9+/bYvn177NmzJyIiduzYcdjdDAyIoX7+ipPHZz7zmRIRZd26dXVd/7+vKbS3t5cPfOADh1w3YcKE0tnZedjbefXVV8u8efPK8OHDy4EDB97wmt7e3jJhwoTS0dFRSillw4YNJSLKrbfeWtda34qIOOJHe3v7gK8BXmenwKD54he/GKNGjYqrrroqent7D/n6c889d8T35g8fPvyQ5+jXrFkTO3fuPOhzu3fvPujPI0eOjPPOOy9KKfHaa69FX19f7N2796BrTj/99Bg/fny8+uqrERFxwQUXxJQpU+LrX/96vPTSS4es5fnnnz/yg43635J69913H/KxYMGCiIi4884741vf+tab3gYcLV5TYNBMmTIlVq9eHQsWLIjp06cf9BvN69atizVr1hzxrKOLL744lixZEldeeWXMmDEjNm3aFD/84Q/zdYDXzZkzJ8aNGxcdHR1xxhlnxNNPPx233HJLzJs3L1paWmLPnj3xjne8Iz72sY9Fe3t7NDc3x8MPPxyPP/54fOMb34iIiGHDhsUdd9wRXV1d0dbWFldeeWVMmDAhdu7cGY888ki0trbGvffee8TH293dHb/61a/e9MXmj3zkI4d87sknn4yIiK6urnj7299+xHk4qoZ4p8JJaMuWLeXqq68ukyZNKiNHjiwtLS2lo6OjrFy5srzyyit53Ru9JfX6668vZ555ZjnllFNKR0dH+e1vf1s6OzsPevro1ltvLbNmzSpjx44tTU1NZcqUKaWnp6fs3bu3lPKfp5N6enpKe3t7aWlpKaNGjSrt7e1l1apVh6z197//fbn00kvztiZOnFguv/zy8otf/OJNH2eVt6T+L29JZajUSqnzPXMAnPC8pgBAEgUAkigAkEQBgCQKACRRACDV/ctrtVptINcBwACr5zcQ7BQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCkEUO9ABgIZ5xxRuWZ3t7eAVjJyeFLX/pS5ZmlS5c2dF+LFi2qPHP77bc3dF8nIzsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkB+IxaObMmdPQ3M0333x0F3IYs2fPrjyzc+fOAVjJ8aetra3yTCmlofvq6uqqPONAvPrZKQCQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAMkpqTRk7NixlWfuuOOOhu5rwoQJlWf27dtXeaalpaXyzLGukce0aNGiyjOf/OQnK89s2bKl8kxExOrVqxuaoz52CgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASLVSSqnrwlptoNfCceTb3/525ZlrrrlmAFbyxj7xiU9Unvnxj388ACsZWtOmTas889RTT1WeaeT7w/nnn195JiLiD3/4Q0NzRNTz7d5OAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIAacRQL4Dj06WXXjpo9/WXv/yl8swTTzwxACs5/pxzzjmDcj+9vb2VZ/75z38OwEp4q+wUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQHIh3gpk0aVLlmeXLl1eeGTduXOWZHTt2VJ6JiOju7q48s3Xr1obu61g2bdq0yjN33nln5ZlarVZ55uqrr648849//KPyDAPPTgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEhOST1GTZ06taG5Rx55pPLM+PHjK8+UUirPLFu2rPJMRMTmzZsrzwwbVv3nnf7+/sozjXjf+97X0NyDDz5Yeaa5ubnyzOLFiyvP3HfffZVnODbZKQCQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIDkQ7xjV1dXV0Fwjh9sNlhUrVgza3M9//vPKM/fcc0/lmUYOqTv77LMrz0Q0drjdvn37Ks+sWbOm8gwnDjsFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkB+Idoy688MKhXsJxbe7cuYMys3///soztVqt8kxEY4fb3XTTTZVnNm/eXHmGE4edAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAUq2UUuq6sMFDvGjMtGnTGpp7//vfX3lmy5YtDd3XYLnooosqz5x11lmVZ+bNm1d5Zty4cZVn6vwnd4gXXnih8sz3v//9yjObNm2qPPO9732v8gyDr56/e3YKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIDsSD/zdz5szKM7/+9a8rzzR6IF4jtm/fXnnm3//+99FfyBtYvXp1Q3Nf/epXj+5CTiIOxAOgElEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEAaMdQLgIEwatSoyjMLFiyoPNPIiaf79++vPBMRccstt1SeWb58eeWZvr6+yjOcOOwUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQaqXOE71qtdpArwWOmieeeKLyzLvf/e7KMw8++GDlmeuvv77yTETEn/70p4bm4HX1fLu3UwAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBox1AuANzNjxozKM+eee+4ArORQjz76aOUZB9txLLNTACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAqpVSSl0X1moDvRZ4Q3fffXflmUsuuaTyzH333Vd5Zv78+ZVnYKjU8+3eTgGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAMmBeDSktbW18sxPfvKThu7roosuqjyzcePGyjNz586tPNPb21t5BoaKA/EAqEQUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRgz1Ajg+vfbaa5VnJk+e3NB91XmQ70E2bNhQecaJp2CnAMB/EQUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgORAPBpy4MCByjNr165t6L4+/elPV5559tlnG7ovONnZKQCQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAINVKKaWuC2u1gV4LJ7hhwxr7GWTMmDGVZ15++eXKM40c8gfHk3q+3dspAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgORDvBDN69OjKMzfeeGPlmZtuuqnyDDC0HIgHQCWiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGA5JRUgJOEU1IBqEQUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAGlEvReWUgZyHQAcA+wUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEj/B2c1ZSjQql30AAAAAElFTkSuQmCC", "text/plain": [ "<Figure size 640x480 with 1 Axes>" ] @@ -147,10 +252,6 @@ } ], "source": [ - "import matplotlib.pyplot as plt\n", - "import random\n", - "from torchvision.transforms import ToPILImage\n", - "\n", "to_pil = ToPILImage()\n", "random_index = random.randint(0, len(train_set) - 1)\n", "image, label = train_set[random_index]\n", @@ -182,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 12, "metadata": { "id": "fxhHKKDFOoHp" }, @@ -208,9 +309,16 @@ " return patches" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Le code ci-dessus coupe une image en n² images que l'on visualise avec la fonction ci-dessous." + ] + }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -222,7 +330,7 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQxUlEQVR4nO3dMYic1QKG4ZnrroKmkCAiwdhopVEIsdBGEANKUFEkhWgjYm3sxSjaiJBGi3QKEQsRMY0GxE4sgtionQFDUNbCgBgIIeJ/ayGXjHpmZve+z1MP3x42h9mX02Q+TdM0AwAg4T/rPgAAAKsj/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQsrHoB+fz+TLPwQ616v8gxj3kSlZ5D91BrsR3IdvBovfQyx8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AAA7ycGDB4dtff7558O2jhw5Mmzr7bffHrY1m81mf/7559A94N/x8gcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELm0zRNC31wPl/2WdiBFrw+w7iH/z/uvPPOYVvff//9sK2rOXv27LCtvXv3Dtsaac+ePUP3fvnll6F725HvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQBg59jc3By29dZbbw3bWqW9e/eu+wgA/4qXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAELKx7gMAO8eBAweGbR06dGjYFgCL8/IHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCNtZ9AGC5rrnmmmFbR48eHbbFWG+++eawrfPnzw/bArYfL38AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtO5DAACwGl7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjGoh+cz+fLPAc71DRNK/157uHfd+211w7bunTp0rAtZrMff/xx2Nb9998/bGtra2vYVoXvQraDRe+hlz8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCyse4DAH917733Dt174403hu7VnTt3btjWwYMHh21tbW0N2wL+v3n5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIRvrPgDwVw899NDQvYcffnjoXt277747bOvMmTPDtgAW5eUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYbZrvfwvvvuG7b1xRdfDNuazWaz66+/fujeKN9+++2wrbvvvnvY1tXcdtttw7bOnTs3bIv18l3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDIfJqmaaEPzufLPgs70ILXZ5jteg8//vjjYVtPPvnksK3RLl++PGzr8ccfH7b12WefDdu6mu16B1kv34VsB4veQy9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZGPdB4B1ufnmm4dt7d+/f9jWaNM0Ddt68cUXh22dOnVq2NZOdcsttwzbev7554dtHTp0aNgWy/HBBx8M2zp27Niwra+//nrYFsvj5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAITMp2maFvrgfL7ss7ADLXh9hhl5Dz/66KNhW0899dSwrdG++eabYVsHDhwYtjXSKu/hyN/ByZMnh23deuutw7Zo+f3334dt3XXXXcO2zp07N2yrYtHvQi9/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZD5N07TuQwAAsBpe/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjYW/eB8Pl/mOdihpmla6c/bv3//sK3Tp08P29rc3By2Ndrrr78+bOuVV14ZtvXss88O2zpx4sSwras5f/78sK3du3cP24Lt4I477hi2debMmWFbFYv+TfbyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQjbWfQD4O2644YZhW5ubm8O2Rrp8+fLQvWPHjm3LrSNHjgzbWqXdu3ev+whLN/oOnjp1atjWV199NWzrmWeeGba1b9++YVuwbF7+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAyMa6D8Dq3X777es+wj/28ssvr/sIS3f8+PGhe48++uiwrZdeemnYFrPZ1tbWsK3Tp08P2zp69OiwrdlsNvvuu++Gbb3wwgvDtvbt2zdsaye7ePHisK0//vhj2BbL4+UPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEbKz7AKzeDz/8sO4j/GOPPPLIuo+wdLt27Rq6d+LEiaF7o1y6dGnY1nXXXTds62pOnjw5bOvpp58etnXTTTcN23rwwQeHbc1mY+/gvn37hm3tZBcuXBi2dfz48WFbZ8+eHbbF8nj5AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkPk3TtNAH5/Nln4UVWfCfHJbqwoULw7Z27do1bOtqPv3005X9rL/jgQceGLa1yt8n/8x77703bOu5554btsV6Lfr33csfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNC30wfl82WdhRT788MNhW4cPHx62tYiff/552NaePXuGbQE7zyeffDJs64knnhi2tYgbb7xx2NZvv/02bIv1WjDpvPwBAJSIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AICQ+TRN00IfnM+XfRZ2oAWvzzBHjx4dtvXaa68N2wL+ty+//HLY1uHDh4dtbW1tDdta9Xehv8lcyaL30MsfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImU/TNK37EAAArIaXPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABCysegH5/P5Ms/BDjVN00p/3ubm5rCte+65Z9jWq6++OmzrscceG7bF9vb+++8P2/rpp5+Gbb3zzjvDtmaz2ezXX38dtnXx4sVhWyOt+rvQ32SuZNF76OUPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEzKdpmhb64Hy+7LOwAy14fYZxD7mSVd5Dd5Ar8V3IdrDoPfTyBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQubTNE3rPgQAAKvh5Q8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACDkv4dFfAX0SOj9AAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQ1klEQVR4nO3dvYtcZQOH4TlhO6MoFoYoQcVAUgVsFMEuCDZRbEQINvGrt9HGP8A0CoqFgkVsRASjFhJBQTvRQoUgBkQR8aMQIUQLTc5bC/uSiTw7s5P7uurhNw/Js2fvnCbTPM/zAgCAhD3rPgAAAKsj/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQsrXsB6dp2slzsKFW/R/EuIdsZ5X30B1kO56F7AbL3kNv/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjWug8ANM3zvO4jELfJd3Dk2Z944olhW6+++uqwLXaON38AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBkmud5XvchAABYDW/+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjWsh+cpmknz8GGmud5pd/nHl65++67b9jWCy+8MGxrpMOHD6/su2655ZZhWz/99NOwLa7cqVOnhm0dP3582NaqnT59etjWgw8+OGyLK7fs72Rv/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMjWug8A/NuNN944dO+1114btnXzzTcP2zp//vywrVW69tpr132EjTPyz+zJJ58ctnX8+PFhW6t27ty5YVtvvPHGsC02gzd/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZJrneV7qg9O002dhAy15fYYp3MOXX3556N5TTz01dG+URx55ZNjWm2++OWzrcgp3cLRDhw4N2zp79uywrU3+uzxy5Miwra+++mrYFuu17O9kb/4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDI1roPAPzbQw89tO4j/F8//PDDsK3PP/982Ba728GDB9d9hG39+uuvw7ZuuummYVvL+O2331b6fVxdvPkDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIGRr3QeAdbn11luHbZ08eXLY1r59+4ZtLRaLxY8//jhs69FHHx229d133w3bYqxDhw4N3Tt16tSwrWmahm099thjw7bee++9YVvL+OWXX1b6fVxdvPkDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAhW+s+AFyJO+64Y9jWxx9/PGxr//79w7bmeR62tVgsFs8///ywrW+++WbY1p49/u050l133TVs68MPPxy2tVgsFnv37h229dxzzw3bev/994dtwSbx9AUACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACFb6z4AXIn7779/2Nb+/fuHbe1mL7744q7c+uCDD4ZtrdLjjz8+bOvMmTPDtm677bZhW3v37h22tVgsFufPnx+29dZbbw3bgipv/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQMg0z/O87kMAALAa3vwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAIRsLfvBaZp28hxsqHmeV/p9p0+fHrZ17NixYVsjjf5ZW/Xf0Tqs8vl06dKlYVsXLlwYtjXyz+DixYvDthaLxeLZZ58dtvXKK68M2xpp1T9nfieznWXvoTd/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZJrneV7qg9O002dhAy15fYY5fPjwsK2777572Na33347bGs3O3r06LCtAwcODNs6ceLEsK3L+fnnn4dt7du3b9jWyJ/F33//fdjWYrFYnDp1atjW119/PWzr9ddfH7a16meh38lsZ9l76M0fAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAImeZ5npf64DTt9FnYQEten2HcQ7azynt47733Dtv69NNPh22t+mfxSnz//ffDtv75559hWyMdPHhw3Uf4zzxXrx7LPge8+QMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACFb6z4A0HTNNdes+wj/ycMPPzxsa57nYVsXLlwYtvXSSy8N21osFouTJ08O27p48eKwrZH++OOPlX7fgQMHVvp9XF28+QMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgZJrneV7qg9O002dhAy15fYZxD68eX3zxxbCtO++8c9jW5Vy6dGnY1pkzZ4ZtPf3008O2zp49O2yrwrOQ3WDZe+jNHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACNla9wGAzXHPPfcM2zp06NCwrU31ySefDNs6e/bssC3g6ubNHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACJnmeZ6X+uA07fRZ2EBLXp9h3MP1euedd4ZtHTt2bNjWKu/Fu+++O2zrgQceGLbFenkWshssew+9+QMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHTPM/zug8BAMBqePMHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQraW/eA0TTt5DjbUPM8r/b7CPbzuuuuG7r399tvDto4ePTps68svvxy2deTIkWFbq1S4zxWehewGy95Db/4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDI1roPAPzb33//PXTv9ttvH7Y1z/Owrc8++2zY1pEjR4ZtXc40TSv7LoCd4M0fAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAI2Vr3AYB/++uvv4buffTRR8O2Tpw4MWzr3Llzw7YAWJ43fwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIGSa53le6oPTtNNnYQMteX2GcQ+v3J494/6Nd8MNNwzb+vPPP3fl1uW4g2zHs5DdYNl76M0fAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHTPM/zUh+cpp0+CxtoyeszzMh7eP311w/beuaZZ3blVsUq76FnIdvZ5GchV49l76E3fwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIGSa53le6oPTtNNnYQMteX2GcQ/ZzirvoTvIdjwL2Q2WvYfe/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAECI+AMACBF/AAAh4g8AIET8AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhIg/AIAQ8QcAECL+AABCxB8AQIj4AwAIEX8AACHiDwAgRPwBAISIPwCAEPEHABAi/gAAQsQfAEDINM/zvO5DAACwGt78AQCEiD8AgBDxBwAQIv4AAELEHwBAiPgDAAgRfwAAIeIPACBE/AEAhPwP7L2SB9Bj9vUAAAAASUVORK5CYII=", "text/plain": [ "<Figure size 800x800 with 16 Axes>" ] @@ -260,25 +368,27 @@ "\n", "Now that we have our flattened patches, we can map each of them through a Linear mapping. While each patch was a 4x4=16 dimensional vector, the linear mapping can map to any arbitrary vector size. Thus, we will use for this a parameter `hidden_d` for \"hidden dimension\".\n", "\n", - "In this example, we will use a hidden dimension of 8, but in principle, any number can be put here. We will thus be mapping each 16-dimensional patch to an 8-dimensional patch.\n" + "In this example, we will use a hidden dimension of 8, but in principle, any number can be put here. We will thus be mapping each 16-dimensional patch to an 8-dimensional patch.\n", + "\n", + "Chaque image sera définit par un vecteur de taille 8 calculer par une couche caché fully connected. De plus on ajoute un 17ème vecteur pour la position?\n" ] }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 18, "metadata": { "id": "RhglaVPb59Ll" }, "outputs": [], "source": [ - "def features_embedding(patch,applatisseur,class_embaded):\n", + "def features_embedding(patch,applatisseur,class_embaded,flattened_patches):\n", " embedded_patches = applatisseur(flattened_patches)\n", " return(torch.cat((class_embaded, embedded_patches), dim=0))" ] }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -292,23 +402,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 0.0886, 0.1783, 0.7162, 0.6681, 0.3244, 0.4643, 0.1293, 0.5872],\n", - " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932],\n", - " [ 0.0683, -0.0921, 0.0700, -0.1282, -0.1309, 0.1637, -0.0534, 0.2885],\n", - " [ 0.0614, -0.1821, 0.1938, 0.2289, -0.5469, -0.1450, 0.2406, -0.0501],\n", - " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932],\n", - " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932],\n", - " [-0.2211, 0.7500, 0.1590, -0.0509, -0.2274, -0.3068, -0.3688, 0.3081],\n", - " [-0.1137, 0.3398, 0.2008, -0.2193, -0.2279, -0.2736, 0.0489, 0.1655],\n", - " [ 0.1418, 0.0633, 0.2066, 0.1069, -0.1668, -0.1283, 0.0725, 0.1086],\n", - " [ 0.0876, -0.1036, -0.0232, 0.0168, -0.0834, 0.0253, -0.1026, 0.1292],\n", - " [-0.1190, 0.6631, 0.4508, 0.1273, -0.8950, -0.2657, -0.1744, 0.1429],\n", - " [-0.0139, 1.1310, 0.6801, -0.0420, -0.8295, -0.3431, -0.3398, 0.1292],\n", - " [ 0.2009, -0.1070, 0.3390, -0.1786, -0.2455, 0.1952, -0.2296, 0.0448],\n", - " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932],\n", - " [ 0.0114, 0.2627, 0.1608, 0.1347, -0.0517, -0.1514, -0.2966, 0.0293],\n", - " [ 0.0699, 0.1150, 0.0422, 0.0093, -0.1582, 0.0118, -0.2497, 0.1767],\n", - " [ 0.0779, -0.0984, -0.0306, 0.0446, -0.0793, 0.0205, -0.1166, 0.0932]],\n", + "tensor([[ 0.3583, 0.3327, 0.8323, 0.8916, 0.5396, 0.1304, 0.2440, 0.3201],\n", + " [-0.0438, -0.0836, 0.1281, -0.1373, -0.0662, -0.0799, 0.0397, -0.0843],\n", + " [-0.0406, -0.0849, 0.1393, -0.1362, -0.0570, -0.0761, 0.0385, -0.0778],\n", + " [-0.0275, -0.0903, 0.1866, -0.1316, -0.0180, -0.0597, 0.0335, -0.0503],\n", + " [-0.1317, -0.1861, 0.1497, -0.2301, -0.1974, -0.2250, -0.1069, -0.0058],\n", + " [-0.0438, -0.0836, 0.1281, -0.1373, -0.0662, -0.0799, 0.0397, -0.0843],\n", + " [ 0.4921, -0.1013, -0.1525, -0.5267, -0.4393, -0.2013, -0.2347, -0.1392],\n", + " [ 0.2613, 0.1353, 0.0309, -0.3266, -0.1960, -0.0995, -0.3109, -0.2549],\n", + " [-0.2865, 0.0438, 0.2934, -0.2267, -0.2556, -0.0191, 0.1033, 0.0397],\n", + " [-0.0438, -0.0836, 0.1281, -0.1373, -0.0662, -0.0799, 0.0397, -0.0843],\n", + " [-0.0365, -0.1025, 0.4465, -0.1549, -0.5510, 0.5609, -0.0412, -0.2608],\n", + " [-0.1204, -0.4502, 0.5061, -0.1478, -0.3524, 0.0705, 0.0807, -0.0759],\n", + " [-0.0144, -0.0733, 0.1701, -0.0786, -0.0807, -0.1167, -0.0150, -0.0425],\n", + " [-0.0438, -0.0836, 0.1281, -0.1373, -0.0662, -0.0799, 0.0397, -0.0843],\n", + " [ 0.3253, -0.3940, 0.1951, -0.1993, -0.1309, 0.1121, 0.0548, -0.2136],\n", + " [-0.0561, -0.0723, 0.1328, -0.1434, -0.0779, -0.0787, 0.0334, -0.0737],\n", + " [-0.0438, -0.0836, 0.1281, -0.1373, -0.0662, -0.0799, 0.0397, -0.0843]],\n", " grad_fn=<CatBackward0>)\n" ] } @@ -320,7 +430,7 @@ "linear_layer = nn.Linear(flattened_patches.size(1), hidden_d)\n", "classe_embedded = torch.rand((1,hidden_d))\n", "\n", - "features_emb=features_embedding(patches,linear_layer,classe_embedded)\n", + "features_emb=features_embedding(patches,linear_layer,classe_embedded,flattened_patches)\n", "print(features_emb)" ] }, @@ -430,6 +540,13 @@ "print(positional_emb)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On a finalement notre Embeded Patches structuré pour entrer dans le Transformer Encoder." + ] + }, { "cell_type": "code", "execution_count": 124, @@ -539,6 +656,7 @@ " seq = sequence[:, head * self.d_head : (head + 1) * self.d_head]\n", " q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)\n", "\n", + " #On calcul l'attention A qui suit les relations décrites ci-dessous\n", " attention = torch.matmul(self.softmax(torch.matmul(q, k.t())/np.sqrt(int(self.d / self.n_heads))),v)\n", "\n", " seq_result.append(attention)\n", @@ -619,6 +737,7 @@ " )\n", "\n", " def forward(self, x):\n", + " # En reprenant la structure du Transformer on calcul la fonction forward.\n", " m1=x\n", " x = self.norm1(x)\n", " x = m1+self.mhsa(x)\n", @@ -705,6 +824,8 @@ "\n", " def forward(self, images):\n", "\n", + " #On refait les étapes présentés précedement\n", + "\n", " # Dividing images into patches\n", " n, c, h, w = images.shape\n", " patches = patchify(images,self.n_patches)\n", @@ -896,11 +1017,17 @@ " for batch in tqdm(train_loader, desc=f\"Epoch {epoch + 1}/{N_EPOCHS}\", leave=False):\n", " x, y = batch\n", " x, y = x.to(device), y.to(device)\n", + "\n", + " #initialisation du gradiant\n", " optimizer.zero_grad()\n", + "\n", " y_hat = model(x)\n", " loss = criterion(y_hat, y)\n", + "\n", + " #calcul du gradiant et evolution de l'optimizer\n", " loss.backward()\n", " optimizer.step()\n", + " \n", " train_loss += loss.detach().cpu().item() / len(train_loader)\n", " print(f\"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}\")\n" ] @@ -959,6 +1086,8 @@ " for batch in tqdm(test_loader, desc=f\"Batch avancement \", leave=False):\n", " x, y = batch\n", " x, y = x.to(device), y.to(device)\n", + "\n", + " #application du transformer et calcul de l'erreur.\n", " y_hat = model(x)\n", " loss = criterion(y_hat, y)\n", " _, pred = torch.max(y_hat, 1)\n", @@ -994,9 +1123,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python PyTorch 1.7.0", + "display_name": "Python 3", "language": "python", - "name": "pytorch-1.7.0" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1008,7 +1137,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.10.12" } }, "nbformat": 4,