From c11b9d685e695fe45911cacf39a7378675a3cac2 Mon Sep 17 00:00:00 2001
From: Zhengfei ZHANG <zhengfei.zhang@ecl21.ec-lyon.fr>
Date: Thu, 21 Nov 2024 08:31:57 +0100
Subject: [PATCH] Added mps device selection for macOS

---
 TD2 Deep Learning.ipynb | 126 ++++++++++++++++++++++++++++++++++------
 1 file changed, 107 insertions(+), 19 deletions(-)

diff --git a/TD2 Deep Learning.ipynb b/TD2 Deep Learning.ipynb
index 00e4fdc..8236680 100644
--- a/TD2 Deep Learning.ipynb	
+++ b/TD2 Deep Learning.ipynb	
@@ -33,10 +33,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "330a42f5",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: torch in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (2.5.1)\n",
+      "Requirement already satisfied: torchvision in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (0.20.1)\n",
+      "Requirement already satisfied: filelock in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (3.13.1)\n",
+      "Requirement already satisfied: typing-extensions>=4.8.0 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (4.11.0)\n",
+      "Requirement already satisfied: networkx in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (3.3)\n",
+      "Requirement already satisfied: jinja2 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (3.1.4)\n",
+      "Requirement already satisfied: fsspec in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (2024.2.0)\n",
+      "Requirement already satisfied: setuptools in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (69.5.1)\n",
+      "Requirement already satisfied: sympy==1.13.1 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torch) (1.13.1)\n",
+      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from sympy==1.13.1->torch) (1.2.1)\n",
+      "Requirement already satisfied: numpy in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torchvision) (2.0.2)\n",
+      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from torchvision) (10.3.0)\n",
+      "Requirement already satisfied: MarkupSafe>=2.0 in /Users/zhangzhengfei/miniconda3/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
    "source": [
     "%pip install torch torchvision"
    ]
@@ -52,10 +73,72 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "id": "b1950f0a",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 0.2254,  0.5296, -0.3807, -0.9261, -0.3824, -0.2763, -0.0083, -0.7543,\n",
+      "          0.5702, -0.8143],\n",
+      "        [ 0.9931,  0.8214,  0.9046,  0.1916,  0.2341,  0.3015, -0.2166, -0.0472,\n",
+      "          1.0145, -1.3083],\n",
+      "        [ 0.3317,  1.8131,  0.0588, -0.0125,  0.4794, -1.1455,  0.5408, -1.0466,\n",
+      "          0.3021,  1.2281],\n",
+      "        [ 0.6281,  1.3689, -1.2237, -0.9255,  2.6865, -0.5821,  1.5674, -1.9834,\n",
+      "         -0.1872, -0.1995],\n",
+      "        [-0.2155,  0.0251,  0.8885, -1.3115, -0.4354,  0.4143,  0.7948, -0.3716,\n",
+      "          0.2296, -0.2655],\n",
+      "        [ 0.8752,  0.1450, -1.2205, -0.9925, -0.4455,  1.3189,  1.7200,  0.3739,\n",
+      "          0.0900,  1.5698],\n",
+      "        [ 0.7289,  0.6868,  0.3100, -1.9949,  1.5346,  0.4060, -0.5150,  0.5518,\n",
+      "         -0.3666, -1.3824],\n",
+      "        [-0.1809, -0.6672,  1.9024, -1.7088,  0.7907,  1.5970, -0.7753,  2.1105,\n",
+      "          0.0121,  0.0303],\n",
+      "        [ 0.2847, -0.8130,  0.3047,  0.7381,  1.4788,  1.5611, -0.2472, -1.2118,\n",
+      "         -1.2564, -0.3072],\n",
+      "        [-0.8530,  1.5380,  0.7398,  0.5787, -1.2414,  1.7822, -0.8175,  2.2730,\n",
+      "          2.0789,  1.1717],\n",
+      "        [ 0.9956,  0.9707, -0.7495,  2.7279, -0.2847, -0.3854, -0.1568,  0.5246,\n",
+      "         -0.3249,  0.8755],\n",
+      "        [ 0.5109,  0.2014, -0.6937, -1.7365,  2.0373, -0.3197, -0.2626,  1.0538,\n",
+      "         -0.6151,  0.3306],\n",
+      "        [ 1.3045,  0.3037,  0.6412,  0.4532, -1.2055, -0.0229,  0.3417,  0.9373,\n",
+      "         -0.1082, -0.0463],\n",
+      "        [-1.2099,  0.3418,  0.6566,  0.6314,  1.0552,  0.0662, -0.9298, -0.5887,\n",
+      "         -0.3136, -0.9497]])\n",
+      "AlexNet(\n",
+      "  (features): Sequential(\n",
+      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
+      "    (1): ReLU(inplace=True)\n",
+      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
+      "    (4): ReLU(inplace=True)\n",
+      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (7): ReLU(inplace=True)\n",
+      "    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (9): ReLU(inplace=True)\n",
+      "    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+      "    (11): ReLU(inplace=True)\n",
+      "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+      "  )\n",
+      "  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
+      "  (classifier): Sequential(\n",
+      "    (0): Dropout(p=0.5, inplace=False)\n",
+      "    (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
+      "    (2): ReLU(inplace=True)\n",
+      "    (3): Dropout(p=0.5, inplace=False)\n",
+      "    (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
+      "    (5): ReLU(inplace=True)\n",
+      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
+      "  )\n",
+      ")\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
@@ -95,20 +178,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "id": "6e18f2fd",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device: mps.\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
-    "# check if CUDA is available\n",
-    "train_on_gpu = torch.cuda.is_available()\n",
+    "# Auto-detect the best device\n",
+    "if torch.backends.mps.is_built():  # MPS for macOS\n",
+    "    device = torch.device(\"mps\")\n",
+    "elif torch.cuda.is_available():  # CUDA for Windows/Linux\n",
+    "    device = torch.device(\"cuda\")\n",
+    "else:  # Default to CPU\n",
+    "    device = torch.device(\"cpu\")\n",
     "\n",
-    "if not train_on_gpu:\n",
-    "    print(\"CUDA is not available.  Training on CPU ...\")\n",
-    "else:\n",
-    "    print(\"CUDA is available!  Training on GPU ...\")"
+    "print(f\"Using device: {device}.\")"
    ]
   },
   {
@@ -926,7 +1019,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.5 ('base')",
+   "display_name": "base",
    "language": "python",
    "name": "python3"
   },
@@ -940,12 +1033,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
-  },
-  "vscode": {
-   "interpreter": {
-    "hash": "9e3efbebb05da2d4a1968abe9a0645745f54b63feb7a85a514e4da0495be97eb"
-   }
+   "version": "3.12.2"
   }
  },
  "nbformat": 4,
-- 
GitLab