From 5eb82bea89388c55b1b9307bfa4d321c85dc3e7c Mon Sep 17 00:00:00 2001 From: youcef kessi <youcef.kessi@etu.univ-lyon1.fr> Date: Thu, 21 Nov 2024 09:59:11 +0100 Subject: [PATCH] modif --- TD2 Deep Learning.ipynb | 55 +++++++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb index 29d74ba..72a622e 100644 --- a/TD2 Deep Learning.ipynb +++ b/TD2 Deep Learning.ipynb @@ -531,17 +531,37 @@ "source": [ "Yes, overfitting occurs.\n", "- Training loss steadily decreases throughout the epochs, reaching very low values.\n", - "- Validation loss decreases initially but starts to increase after epoch 17, suggesting that the model is overfitting to the training data and not generalizing well to the validation data.\n", - "\n", - "And the others indicators of this overfitting is the divergence between training and validation losses after a certain point and also validation loss starts to increase while the training loss continues to decrease" + "- Validation loss decreases initially but starts to increase after epoch 17 while the training loss continues to decrease, suggesting that the model is overfitting to the training data and not generalizing well to the validation data." + ] + }, + { + "cell_type": "markdown", + "id": "086dc438", + "metadata": {}, + "source": [ + "To address this, I can implement early stopping. \n", + "Given that the validation loss decreases until epoch 17, then rises again afterwards, it is clear that overfitting begins at this point.\n", + "So, i do an early stopping at 17 epochs." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "d39df818", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'matplotlib'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mmatplotlib\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpyplot\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mplt\u001b[39;00m\n\u001b[1;32m 3\u001b[0m plt\u001b[39m.\u001b[39mplot(\u001b[39mrange\u001b[39m(n_epochs), train_loss_list)\n\u001b[1;32m 4\u001b[0m plt\u001b[39m.\u001b[39mxlabel(\u001b[39m\"\u001b[39m\u001b[39mEpoch\u001b[39m\u001b[39m\"\u001b[39m)\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'matplotlib'" + ] + } + ], "source": [ "import matplotlib.pyplot as plt\n", "\n", @@ -562,10 +582,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "e93efdfc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Loss: 33.489301\n", + "\n", + "Test Accuracy of airplane: 71% (714/1000)\n", + "Test Accuracy of automobile: 72% (727/1000)\n", + "Test Accuracy of bird: 46% (464/1000)\n", + "Test Accuracy of cat: 31% (312/1000)\n", + "Test Accuracy of deer: 57% (571/1000)\n", + "Test Accuracy of dog: 58% (584/1000)\n", + "Test Accuracy of frog: 64% (641/1000)\n", + "Test Accuracy of horse: 69% (692/1000)\n", + "Test Accuracy of ship: 61% (619/1000)\n", + "Test Accuracy of truck: 64% (645/1000)\n", + "\n", + "Test Accuracy (Overall): 59% (5969/10000)\n" + ] + } + ], "source": [ "model.load_state_dict(torch.load(\"./model_cifar.pt\"))\n", "\n", -- GitLab