From 3887ae5cf6153cf44b543b20965968a3a8505882 Mon Sep 17 00:00:00 2001
From: Emmanuel Dellandrea <emmanuel.dellandrea@ec-lyon.fr>
Date: Tue, 25 Feb 2025 06:44:50 +0100
Subject: [PATCH] Update Subject_7_LLM.ipynb

---
 Practical_sessions/Session_7/Subject_7_LLM.ipynb | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/Practical_sessions/Session_7/Subject_7_LLM.ipynb b/Practical_sessions/Session_7/Subject_7_LLM.ipynb
index 7101e1e..ac18981 100644
--- a/Practical_sessions/Session_7/Subject_7_LLM.ipynb
+++ b/Practical_sessions/Session_7/Subject_7_LLM.ipynb
@@ -703,7 +703,7 @@
    "source": [
     "model_save_path = \"custom_bert_model.pth\"\n",
     "\n",
-    "torch.save(model.state_dict(), model_save_path)\n"
+    "torch.save(model.state_dict(), model_save_path)"
    ]
   },
   {
@@ -718,8 +718,7 @@
     "\n",
     "loaded_model.load_state_dict(torch.load(model_save_path))\n",
     "\n",
-    "loaded_model.to(device)\n",
-    "\n"
+    "loaded_model.to(device)\n"
    ]
   },
   {
@@ -777,11 +776,9 @@
     "            nn.Linear(config.hidden_size, 128),\n",
     "            nn.ReLU(),\n",
     "            nn.Dropout(0.1),\n",
-    "            nn.Linear(128, config.num_labels)  # Binary classification\n",
+    "            nn.Linear(128, config.num_labels) \n",
     "        )\n",
-    "        self.init_weights()\n",
     "\n",
-    "        # Freeze DistilBERT backbone if specified\n",
     "        if freeze_backbone:\n",
     "            for param in self.distilbert.parameters():\n",
     "                param.requires_grad = False\n",
@@ -789,8 +786,7 @@
     "    def forward(self, input_ids, attention_mask=None, labels=None):\n",
     "        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)\n",
     "        logits = self.classifier(outputs.last_hidden_state[:, 0, :])  # Use [CLS] token output\n",
-    "        return logits\n",
-    "\n"
+    "        return logits\n"
    ]
   },
   {
@@ -812,15 +808,12 @@
    "source": [
     "from transformers import DistilBertTokenizer\n",
     "\n",
-    "# Initialize the configuration with custom attributes\n",
     "config = AutoConfig.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n",
     "config.architectures = [\"CustomDistilBERTModel\"]\n",
     "\n",
-    "# Initialize the model and tokenizer\n",
     "model = CustomDistilBERTModel(config)\n",
     "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
     "\n",
-    "# Save locally\n",
     "model.save_pretrained(\"custom_distilbert_model\")\n",
     "tokenizer.save_pretrained(\"custom_distilbert_model\")\n",
-- 
GitLab
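
For reference, below is a minimal, self-contained sketch of how the patched cells fit together. The class declaration, imports, and training setup lie outside the diff context, so the PreTrainedModel base class, the freeze_backbone default, and the device variable are assumptions, not part of the patch itself.

# Hypothetical sketch of the post-patch notebook cells; names marked as
# assumptions below are not shown in the diff.
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, DistilBertTokenizer, PreTrainedModel


class CustomDistilBERTModel(PreTrainedModel):  # assumed base class; it provides save_pretrained()
    def __init__(self, config, freeze_backbone=False):  # freeze_backbone default is assumed
        super().__init__(config)
        self.distilbert = AutoModel.from_pretrained("distilbert-base-uncased")
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, config.num_labels),
        )
        # Optionally freeze the backbone so only the classifier head trains
        if freeze_backbone:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the final hidden state of the first ([CLS]) token
        return self.classifier(outputs.last_hidden_state[:, 0, :])


device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: defined earlier in the notebook

config = AutoConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
config.architectures = ["CustomDistilBERTModel"]
model = CustomDistilBERTModel(config).to(device)

# state_dict round trip, as in the first two patched cells
model_save_path = "custom_bert_model.pth"
torch.save(model.state_dict(), model_save_path)
loaded_model = CustomDistilBERTModel(config)  # assumption: constructed the same way as model
loaded_model.load_state_dict(torch.load(model_save_path))
loaded_model.to(device)

# Hugging Face-style directory save, as in the last patched cell
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model.save_pretrained("custom_distilbert_model")
tokenizer.save_pretrained("custom_distilbert_model")

Note that the patch drops the self.init_weights() call from __init__; under the assumed PreTrainedModel base class that call would reinitialize weights, which is unwanted when the backbone is loaded from pretrained checkpoints.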