From 3887ae5cf6153cf44b543b20965968a3a8505882 Mon Sep 17 00:00:00 2001
From: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue, 25 Feb 2025 06:44:50 +0100
Subject: [PATCH] Clean up code cells in Subject_7_LLM.ipynb

Remove redundant inline comments, the self.init_weights() call, and stray
trailing blank lines from the save/load and custom-model cells.

---
 Practical_sessions/Session_7/Subject_7_LLM.ipynb | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/Practical_sessions/Session_7/Subject_7_LLM.ipynb b/Practical_sessions/Session_7/Subject_7_LLM.ipynb
index 7101e1e..ac18981 100644
--- a/Practical_sessions/Session_7/Subject_7_LLM.ipynb
+++ b/Practical_sessions/Session_7/Subject_7_LLM.ipynb
@@ -703,7 +703,7 @@
    "source": [
     "model_save_path = \"custom_bert_model.pth\"\n",
     "\n",
-    "torch.save(model.state_dict(), model_save_path)\n"
+    "torch.save(model.state_dict(), model_save_path)"
    ]
   },
   {
@@ -718,8 +718,7 @@
     "\n",
     "loaded_model.load_state_dict(torch.load(model_save_path))\n",
     "\n",
-    "loaded_model.to(device)\n",
-    "\n"
+    "loaded_model.to(device)\n"
    ]
   },
   {
@@ -777,11 +776,9 @@
     "            nn.Linear(config.hidden_size, 128),\n",
     "            nn.ReLU(),\n",
     "            nn.Dropout(0.1),\n",
-    "            nn.Linear(128, config.num_labels)  # Binary classification\n",
+    "            nn.Linear(128, config.num_labels)  \n",
     "        )\n",
-    "        self.init_weights()\n",
     "\n",
-    "        # Freeze DistilBERT backbone if specified\n",
     "        if freeze_backbone:\n",
     "            for param in self.distilbert.parameters():\n",
     "                param.requires_grad = False\n",
@@ -789,8 +786,7 @@
     "    def forward(self, input_ids, attention_mask=None, labels=None):\n",
     "        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)\n",
     "        logits = self.classifier(outputs.last_hidden_state[:, 0, :])  # Use [CLS] token output\n",
-    "        return logits\n",
-    "\n"
+    "        return logits\n"
    ]
   },
   {
@@ -812,15 +808,12 @@
    "source": [
     "from transformers import DistilBertTokenizer\n",
     "\n",
-    "# Initialize the configuration with custom attributes\n",
     "config = AutoConfig.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n",
     "config.architectures = [\"CustomDistilBERTModel\"]\n",
     "\n",
-    "# Initialize the model and tokenizer\n",
     "model = CustomDistilBERTModel(config)\n",
     "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
     "\n",
-    "# Save locally\n",
     "model.save_pretrained(\"custom_distilbert_model\")\n",
     "tokenizer.save_pretrained(\"custom_distilbert_model\")\n",
     "\n",
-- 
GitLab
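
For reference, a minimal sketch of the state_dict round trip the first two
hunks touch. The notebook's custom model class is not visible in this patch,
so a plain BertModel stands in for it, and the device line is an assumption:

import torch
from transformers import AutoConfig, BertModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the notebook's model; any nn.Module works the same way here.
config = AutoConfig.from_pretrained("bert-base-uncased")
model = BertModel(config)

model_save_path = "custom_bert_model.pth"
torch.save(model.state_dict(), model_save_path)  # weights only, no architecture

# Loading requires rebuilding the same architecture first.
loaded_model = BertModel(config)
loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
loaded_model.to(device)
loaded_model.eval()  # disable dropout for inference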
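
The class edited by the middle two hunks, reconstructed as a self-contained
sketch. The PreTrainedModel base class, the config_class attribute, and the
__init__ signature are assumptions inferred from the freeze_backbone check
and the save_pretrained() call in the last hunk:

import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel, PreTrainedModel

class CustomDistilBERTModel(PreTrainedModel):
    config_class = DistilBertConfig  # assumed; needed for from_pretrained()

    def __init__(self, config, freeze_backbone=True):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, config.num_labels),
        )
        # Freeze the backbone so only the small classifier head trains.
        if freeze_backbone:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # First-token ([CLS]-position) representation feeds the head.
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        return logits

Note that with self.init_weights() dropped by this patch, the head keeps
PyTorch's default Linear initialization rather than the transformers
initializer_range scheme.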
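
And the save path from the last hunk, extended with the reload step it
implies. The from_pretrained() call only resolves because of the config_class
attribute assumed in the sketch above; AutoModel cannot load a custom class
like this without extra registration:

from transformers import AutoConfig, DistilBertTokenizer

config = AutoConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
config.architectures = ["CustomDistilBERTModel"]

model = CustomDistilBERTModel(config)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Writes config.json plus the weights into the folder.
model.save_pretrained("custom_distilbert_model")
tokenizer.save_pretrained("custom_distilbert_model")

# Reloading goes through the custom class itself.
reloaded = CustomDistilBERTModel.from_pretrained("custom_distilbert_model")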