diff --git a/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb b/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb
index d05bc3577ae8b4d91d45b28ab935367fbaea9ff6..68a4f33337ff06254919d61721a04bff8b76a8b8 100644
--- a/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb
+++ b/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "e687aca8",
+   "id": "182bf66e",
    "metadata": {},
    "source": [
     "GEREST CORENTIN\n",
@@ -49,125 +49,46 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 1,
    "id": "c49497c0",
-   "metadata": {},
+   "metadata": {
+    "scrolled": false
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Requirement already satisfied: numpy==1.25.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.25.0)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install numpy==1.25.0 --user"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "bb4d7c39",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: gym==0.26.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (0.26.2)\n",
       "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (4.8.1)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n",
-      "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n",
       "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (0.0.8)\n",
+      "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n",
+      "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n",
       "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gym==0.26.2) (3.6.0)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install gym==0.26.2"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c4339dd2",
-   "metadata": {},
-   "source": [
-    "Install also pyglet for the rendering."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "ae74426f",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: pyglet==2.0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.0.10)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install pyglet==2.0.10"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "d6a2d90b",
-   "metadata": {},
-   "source": [
-    "If needed "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "id": "712fb75a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: pygame==2.5.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.5.2)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install pygame==2.5.2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "id": "3cdd7bcc",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: PyQt5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (5.15.10)\n",
       "Requirement already satisfied: PyQt5-Qt5>=5.15.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (5.15.2)\n",
-      "Note: you may need to restart the kernel to use updated packages.Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n",
-      "\n"
+      "Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n"
      ]
     }
    ],
    "source": [
-    "pip install PyQt5"
+    "!pip install numpy==1.25.0 --user\n",
+    "!pip install gym==0.26.2\n",
+    "\n",
+    "# Install also pyglet for the rendering.\n",
+    "!pip install pyglet==2.0.10\n",
+    "\n",
+    "# If needed\n",
+    "!pip install pygame==2.5.2\n",
+    "!pip install PyQt5"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "79581e82",
+   "id": "51bd01b6",
    "metadata": {},
    "source": [
     "### Usage\n",
@@ -177,10 +98,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 2,
    "id": "800853bd",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.4SP5SUA7CBGXUEOC35YP2ASOICYYEQZZ.gfortran-win_amd64.dll\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-gcc_10_3_0.dll\n",
+      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
+      "  if not isinstance(terminated, (bool, np.bool8)):\n"
+     ]
+    }
+   ],
    "source": [
     "import gym\n",
     "\n",
@@ -208,7 +142,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "fea5e4cf",
+   "id": "5aa96f10",
    "metadata": {},
    "source": [
     "## REINFORCE\n",
@@ -236,8 +170,168 @@
     "To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/introduction).\n",
     "\n",
     "> 🛠 **To be handed in**\n",
-    "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`.\n",
+    "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f3e912fb",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c901dca8dc49499f91778a2439ab7250",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/500 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\coren\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py:58: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:264.)\n",
+      "  state_tensor = torch.FloatTensor(state).unsqueeze(0)\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "expected sequence of length 4 at dim 1 (got 0)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     56\u001b[0m     \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m         \u001b[1;31m# Compute action probabilities\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 58\u001b[1;33m         \u001b[0mstate_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     59\u001b[0m         \u001b[0maction_probs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mValueError\u001b[0m: expected sequence of length 4 at dim 1 (got 0)"
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "import numpy as np\n",
+    "from tqdm.notebook import tqdm\n",
+    "\n",
+    "\n",
+    "# Define the neural network for the policy\n",
+    "class PolicyNetwork(nn.Module):\n",
+    "    def __init__(self, input_dim, output_dim):\n",
+    "        super(PolicyNetwork, self).__init__()\n",
+    "        self.fc1 = nn.Linear(input_dim, 128)\n",
+    "        self.relu = nn.ReLU()\n",
+    "        self.dropout = nn.Dropout(p=0.6)\n",
+    "        self.fc2 = nn.Linear(128, output_dim)\n",
+    "        self.softmax = nn.Softmax(dim=1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.fc1(x)\n",
+    "        x = self.relu(x)\n",
+    "        x = self.dropout(x)\n",
+    "        x = self.fc2(x)\n",
+    "        return self.softmax(x)\n",
+    "\n",
+    "# Normalize function\n",
+    "def normalize_rewards(rewards):\n",
+    "    rewards = np.array(rewards)\n",
+    "    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)\n",
+    "    return rewards\n",
+    "\n",
+    "# Hyperparameters\n",
+    "learning_rate = 5e-3\n",
+    "gamma = 0.99\n",
+    "episodes = 500\n",
+    "\n",
+    "# Environment setup\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "\n",
+    "input_dim = env.observation_space.shape[0]\n",
+    "output_dim = env.action_space.n\n",
+    "\n",
+    "# Policy network\n",
+    "policy = PolicyNetwork(input_dim, output_dim)\n",
+    "optimizer = optim.Adam(policy.parameters(), lr=learning_rate)\n",
+    "\n",
+    "# Training loop\n",
+    "episode_rewards = []\n",
+    "for episode in tqdm(range(episodes)):\n",
+    "\n",
+    "    state = env.reset()\n",
+    "    episode_reward = 0\n",
+    "    saved_log_probs = []\n",
+    "    rewards = []\n",
+    "\n",
+    "    while True:\n",
+    "        # Compute action probabilities\n",
+    "        state_tensor = torch.FloatTensor(state).unsqueeze(0)\n",
+    "        action_probs = policy(state_tensor)\n",
     "\n",
+    "        # Sample action\n",
+    "        action = torch.multinomial(action_probs, 1).item()\n",
+    "        log_prob = torch.log(action_probs.squeeze(0)[action])\n",
+    "        saved_log_probs.append(log_prob)\n",
+    "\n",
+    "        # Step env with action\n",
+    "        next_state, reward, done, _ = env.step(action)\n",
+    "        rewards.append(reward)\n",
+    "        episode_reward += reward\n",
+    "\n",
+    "\n",
+    "        if done:\n",
+    "            # Compute return\n",
+    "            returns = []\n",
+    "            R = 0\n",
+    "            for r in rewards[::-1]:\n",
+    "                R = r + gamma * R\n",
+    "                returns.insert(0, R)\n",
+    "\n",
+    "            # Normalize returns\n",
+    "            returns = normalize_rewards(returns)\n",
+    "\n",
+    "            # Update policy\n",
+    "            policy_loss = torch.stack(saved_log_probs) * torch.tensor(returns)\n",
+    "            policy_loss = -policy_loss.sum()\n",
+    "            optimizer.zero_grad()\n",
+    "            policy_loss.backward()\n",
+    "            optimizer.step()\n",
+    "\n",
+    "            episode_rewards.append(episode_reward)\n",
+    "            break\n",
+    "\n",
+    "        state = next_state"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "06261130",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plotting\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "plt.plot(episode_rewards)\n",
+    "plt.xlabel('Episode')\n",
+    "plt.ylabel('Total Reward')\n",
+    "plt.title('REINFORCE on CartPole')\n",
+    "plt.savefig('reinforce_cartpole_rewards2.png')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fea5e4cf",
+   "metadata": {},
+   "source": [
     "## Familiarization with a complete RL pipeline: Application to training a robotic arm\n",
     "\n",
     "In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.\n",
@@ -251,7 +345,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 4,
    "id": "7b5d4e63",
    "metadata": {
     "scrolled": true
@@ -261,50 +355,45 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Collecting stable-baselines3\n",
-      "  Using cached stable_baselines3-2.2.1-py3-none-any.whl (181 kB)\n",
+      "Requirement already satisfied: stable-baselines3 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (2.2.1)\n",
+      "Requirement already satisfied: gymnasium<0.30,>=0.28.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from stable-baselines3) (0.29.1)\n",
       "Requirement already satisfied: matplotlib in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (3.4.3)\n",
-      "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.26.4)\n",
+      "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.25.0)\n",
       "Requirement already satisfied: pandas in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.3.4)\n",
-      "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n",
       "Requirement already satisfied: cloudpickle in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.0.0)\n",
-      "Collecting gymnasium<0.30,>=0.28.1\n",
-      "  Using cached gymnasium-0.29.1-py3-none-any.whl (953 kB)\n",
-      "Collecting farama-notifications>=0.0.1\n",
-      "  Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
-      "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n",
+      "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n",
       "Requirement already satisfied: typing-extensions>=4.3.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.4.0)\n",
+      "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n",
+      "Requirement already satisfied: farama-notifications>=0.0.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)\n",
       "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gymnasium<0.30,>=0.28.1->stable-baselines3) (3.6.0)\n",
-      "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n",
+      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n",
       "Requirement already satisfied: sympy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (1.9)\n",
-      "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n",
       "Requirement already satisfied: networkx in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.6.3)\n",
-      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n",
+      "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n",
+      "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n",
       "Requirement already satisfied: MarkupSafe>=0.23 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from jinja2->torch>=1.13->stable-baselines3) (1.1.1)\n",
-      "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n",
       "Requirement already satisfied: pillow>=6.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (8.4.0)\n",
-      "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n",
       "Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (1.3.1)\n",
+      "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n",
+      "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n",
       "Requirement already satisfied: cycler>=0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (0.10.0)\n",
       "Requirement already satisfied: six in c:\\users\\coren\\anaconda3\\lib\\site-packages (from cycler>=0.10->matplotlib->stable-baselines3) (1.16.0)\n",
       "Requirement already satisfied: pytz>=2017.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from pandas->stable-baselines3) (2021.3)\n",
       "Requirement already satisfied: mpmath>=0.19 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from sympy->torch>=1.13->stable-baselines3) (1.2.1)\n",
-      "Installing collected packages: farama-notifications, gymnasium, stable-baselines3\n",
-      "Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1 stable-baselines3-2.2.1\n",
       "Requirement already satisfied: moviepy in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.0.3)\n",
-      "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n",
+      "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n",
       "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.26.0)\n",
-      "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n",
-      "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.26.4)\n",
       "Requirement already satisfied: proglog<=1.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.1.10)\n",
-      "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n",
       "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.62.3)\n",
+      "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n",
+      "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.25.0)\n",
+      "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n",
       "Requirement already satisfied: pillow in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (8.4.0)\n",
       "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio-ffmpeg>=0.2.0->moviepy) (58.0.4)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n",
       "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.2)\n",
-      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n",
       "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.0.4)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n",
       "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.4)\n"
      ]
     }
@@ -316,7 +405,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "id": "14c1bc63",
    "metadata": {},
    "outputs": [],
@@ -332,7 +421,7 @@
     "\n",
     "vec_env = model.get_env()\n",
     "obs = vec_env.reset()\n",
-    "for i in range(1000):\n",
+    "for i in tqdm(range(1000)):\n",
     "    action, _state = model.predict(obs, deterministic=True)\n",
     "    obs, reward, done, info = vec_env.step(action)\n",
     "    #vec_env.render(\"human\")\n",
@@ -364,10 +453,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "id": "cd890835",
-   "metadata": {},
-   "outputs": [],
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting huggingface-sb3==2.3.1\n",
+      "  Downloading huggingface_sb3-2.3.1-py3-none-any.whl (9.5 kB)\n",
+      "Collecting huggingface-hub~=0.8\n",
+      "  Downloading huggingface_hub-0.20.3-py3-none-any.whl (330 kB)\n",
+      "Requirement already satisfied: pyyaml~=6.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (6.0)\n",
+      "Collecting wasabi\n",
+      "  Downloading wasabi-1.1.2-py3-none-any.whl (27 kB)\n",
+      "Requirement already satisfied: cloudpickle>=1.6 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (2.0.0)\n",
+      "Requirement already satisfied: numpy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (1.25.0)\n",
+      "Requirement already satisfied: packaging>=20.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (21.0)\n",
+      "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.4.0)\n",
+      "Requirement already satisfied: requests in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.26.0)\n",
+      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.3.1)\n",
+      "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.62.3)\n",
+      "Collecting fsspec>=2023.5.0\n",
+      "  Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n",
+      "Requirement already satisfied: pyparsing>=2.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from packaging>=20.9->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.0.4)\n",
+      "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm>=4.42.1->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (0.4.4)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2021.10.8)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.0.4)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.2)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (1.26.7)\n",
+      "Collecting colorama\n",
+      "  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
+      "Installing collected packages: colorama, fsspec, wasabi, huggingface-hub, huggingface-sb3\n",
+      "  Attempting uninstall: colorama\n",
+      "    Found existing installation: colorama 0.4.4\n",
+      "    Uninstalling colorama-0.4.4:\n",
+      "      Successfully uninstalled colorama-0.4.4\n",
+      "  Attempting uninstall: fsspec\n",
+      "    Found existing installation: fsspec 2021.10.1\n",
+      "    Uninstalling fsspec-2021.10.1:\n",
+      "      Successfully uninstalled fsspec-2021.10.1\n",
+      "Successfully installed colorama-0.4.6 fsspec-2024.2.0 huggingface-hub-0.20.3 huggingface-sb3-2.3.1 wasabi-1.1.2\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
    "source": [
     "pip install huggingface-sb3==2.3.1"
    ]
@@ -398,14 +531,238 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "id": "6645c23a",
-   "metadata": {},
-   "outputs": [],
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting wandb\n",
+      "  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)\n",
+      "Collecting tensorboard\n",
+      "  Downloading tensorboard-2.16.2-py3-none-any.whl (5.5 MB)\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+      "conda-repo-cli 1.0.4 requires pathlib, which is not installed.\n",
+      "anaconda-project 0.10.1 requires ruamel-yaml, which is not installed.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: requests<3,>=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n",
+      "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (58.0.4)\n",
+      "Requirement already satisfied: appdirs>=1.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (1.4.4)\n",
+      "Requirement already satisfied: PyYAML in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n",
+      "Requirement already satisfied: typing-extensions in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (4.4.0)\n",
+      "Collecting GitPython!=3.1.29,>=1.0.0\n",
+      "  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)\n",
+      "Collecting sentry-sdk>=1.0.0\n",
+      "  Downloading sentry_sdk-1.40.5-py2.py3-none-any.whl (258 kB)\n",
+      "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n",
+      "Collecting setproctitle\n",
+      "  Downloading setproctitle-1.3.3-cp39-cp39-win_amd64.whl (11 kB)\n",
+      "Collecting protobuf!=4.21.0,<5,>=3.19.0\n",
+      "  Downloading protobuf-4.25.3-cp39-cp39-win_amd64.whl (413 kB)\n",
+      "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n",
+      "Collecting docker-pycreds>=0.4.0\n",
+      "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
+      "Requirement already satisfied: six>1.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.16.0)\n",
+      "Collecting tensorboard-data-server<0.8.0,>=0.7.0\n",
+      "  Downloading tensorboard_data_server-0.7.2-py3-none-any.whl (2.4 kB)\n",
+      "Requirement already satisfied: werkzeug>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (2.0.2)\n",
+      "Collecting markdown>=2.6.8\n",
+      "  Downloading Markdown-3.5.2-py3-none-any.whl (103 kB)\n",
+      "Collecting absl-py>=0.4\n",
+      "  Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)\n",
+      "Requirement already satisfied: numpy>=1.12.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.25.0)\n",
+      "Collecting grpcio>=1.48.2\n",
+      "  Downloading grpcio-1.62.0-cp39-cp39-win_amd64.whl (3.8 MB)\n",
+      "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n",
+      "Collecting gitdb<5,>=4.0.1\n",
+      "  Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
+      "Collecting smmap<6,>=3.0.1\n",
+      "  Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n",
+      "Requirement already satisfied: importlib-metadata>=4.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from markdown>=2.6.8->tensorboard) (4.8.1)\n",
+      "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (3.6.0)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n",
+      "Collecting urllib3<1.27,>=1.21.1\n",
+      "  Downloading urllib3-1.26.18-py2.py3-none-any.whl (143 kB)\n",
+      "Installing collected packages: smmap, urllib3, gitdb, tensorboard-data-server, setproctitle, sentry-sdk, protobuf, markdown, grpcio, GitPython, docker-pycreds, absl-py, wandb, tensorboard\n",
+      "  Attempting uninstall: urllib3\n",
+      "    Found existing installation: urllib3 1.26.7\n",
+      "    Uninstalling urllib3-1.26.7:\n",
+      "      Successfully uninstalled urllib3-1.26.7\n",
+      "Successfully installed GitPython-3.1.42 absl-py-2.1.0 docker-pycreds-0.4.0 gitdb-4.0.11 grpcio-1.62.0 markdown-3.5.2 protobuf-4.25.3 sentry-sdk-1.40.5 setproctitle-1.3.3 smmap-5.0.1 tensorboard-2.16.2 tensorboard-data-server-0.7.2 urllib3-1.26.18 wandb-0.16.3\n"
+     ]
+    }
+   ],
    "source": [
     "pip install wandb tensorboard"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "f94466df",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:\n",
+      "Traceback (most recent call last):\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 1172, in init\n",
+      "    wi.setup(kwargs)\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 306, in setup\n",
+      "    wandb_login._login(\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 317, in _login\n",
+      "    wlogin.prompt_api_key()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 240, in prompt_api_key\n",
+      "    key, status = self._prompt_api_key()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 220, in _prompt_api_key\n",
+      "    key = apikey.prompt_api_key(\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\", line 151, in prompt_api_key\n",
+      "    key = input_callback(api_ask).strip()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 168, in prompt\n",
+      "    value = prompt_func(prompt)\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 150, in prompt_func\n",
+      "    raise Abort() from None\n",
+      "click.exceptions.Abort\n"
+     ]
+    },
+    {
+     "ename": "Error",
+     "evalue": "An unexpected error occurred",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAbort\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m   1171\u001b[0m         \u001b[0mwi\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_WandbInit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1172\u001b[1;33m         \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1173\u001b[0m         \u001b[1;32massert\u001b[0m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36msetup\u001b[1;34m(self, kwargs)\u001b[0m\n\u001b[0;32m    305\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_offline\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_noop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m             wandb_login._login(\n\u001b[0m\u001b[0;32m    307\u001b[0m                 \u001b[0manonymous\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"anonymous\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_login\u001b[1;34m(anonymous, key, relogin, host, force, timeout, _backend, _silent, _disable_warning, _entity)\u001b[0m\n\u001b[0;32m    316\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 317\u001b[1;33m         \u001b[0mwlogin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    318\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    239\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 240\u001b[1;33m         \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_prompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    241\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mApiKeyStatus\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNOTTY\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_prompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    219\u001b[0m             \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 220\u001b[1;33m                 key = apikey.prompt_api_key(\n\u001b[0m\u001b[0;32m    221\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(settings, api, input_callback, browser_callback, no_offline, no_create, local)\u001b[0m\n\u001b[0;32m    150\u001b[0m             )\n\u001b[1;32m--> 151\u001b[1;33m             \u001b[0mkey\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_callback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mapi_ask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    152\u001b[0m         \u001b[0mwrite_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt\u001b[1;34m(text, default, hide_input, confirmation_prompt, type, value_proc, prompt_suffix, show_default, err, show_choices)\u001b[0m\n\u001b[0;32m    167\u001b[0m         \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 168\u001b[1;33m             \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mprompt_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    169\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt_func\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m    149\u001b[0m                 \u001b[0mecho\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0merr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mAbort\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    151\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAbort\u001b[0m: ",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "\u001b[1;31mError\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2820517193.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     12\u001b[0m     \u001b[1;34m\"env_name\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m }\n\u001b[1;32m---> 14\u001b[1;33m run = wandb.init(\n\u001b[0m\u001b[0;32m     15\u001b[0m     \u001b[0mproject\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"sb3\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m     \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m   1212\u001b[0m                 \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtermerror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Abnormal program exit\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1213\u001b[0m                 \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_exit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1214\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"An unexpected error occurred\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0merror_seen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1215\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mError\u001b[0m: An unexpected error occurred"
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "from stable_baselines3 import PPO\n",
+    "from stable_baselines3.common.monitor import Monitor\n",
+    "from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder\n",
+    "import wandb\n",
+    "from wandb.integration.sb3 import WandbCallback\n",
+    "\n",
+    "\n",
+    "config = {\n",
+    "    \"policy_type\": \"MlpPolicy\",\n",
+    "    \"total_timesteps\": 25000,\n",
+    "    \"env_name\": \"CartPole-v1\",\n",
+    "}\n",
+    "run = wandb.init(\n",
+    "    project=\"sb3\",\n",
+    "    config=config,\n",
+    "    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics\n",
+    "    monitor_gym=True,  # auto-upload the videos of agents playing the game\n",
+    "    save_code=True,  # optional\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def make_env():\n",
+    "    env = gym.make(config[\"env_name\"])\n",
+    "    env = Monitor(env)  # record stats such as returns\n",
+    "    return env\n",
+    "\n",
+    "\n",
+    "env = DummyVecEnv([make_env])\n",
+    "env = VecVideoRecorder(\n",
+    "    env,\n",
+    "    f\"videos/{run.id}\",\n",
+    "    record_video_trigger=lambda x: x % 2000 == 0,\n",
+    "    video_length=200,\n",
+    ")\n",
+    "model = PPO(config[\"policy_type\"], env, verbose=1, tensorboard_log=f\"runs/{run.id}\")\n",
+    "model.learn(\n",
+    "    total_timesteps=config[\"total_timesteps\"],\n",
+    "    callback=WandbCallback(\n",
+    "        gradient_save_freq=100,\n",
+    "        model_save_path=f\"models/{run.id}\",\n",
+    "        verbose=2,\n",
+    "    ),\n",
+    ")\n",
+    "run.finish()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "4dda548e",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "Error",
+     "evalue": "You must call wandb.init() before WandbCallback()",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mError\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2131705787.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10_000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mWandbCallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     12\u001b[0m \u001b[0mvec_env\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\integration\\sb3\\sb3.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, verbose, model_save_path, model_save_freq, gradient_save_freq, log)\u001b[0m\n\u001b[0;32m     95\u001b[0m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     96\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"You must call wandb.init() before WandbCallback()\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     98\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mwb_telemetry\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     99\u001b[0m             \u001b[0mtel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msb3\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mError\u001b[0m: You must call wandb.init() before WandbCallback()"
+     ]
+    }
+   ],
+   "source": [
+    "from wandb.integration.sb3 import WandbCallback\n",
+    "\n",
+    "import gymnasium as gym\n",
+    "from stable_baselines3 import A2C\n",
+    "\n",
+    "\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "\n",
+    "model = A2C(\"MlpPolicy\", env, verbose=0)\n",
+    "model.learn(total_timesteps=10_000, callback=WandbCallback())\n",
+    "\n",
+    "vec_env = model.get_env()\n",
+    "obs = vec_env.reset()\n",
+    "for i in tqdm(range(1000)):\n",
+    "    action, _state = model.predict(obs, deterministic=True)\n",
+    "    obs, reward, done, info = vec_env.step(action)\n",
+    "    #vec_env.render(\"human\")\n",
+    "    # VecEnv resets automatically\n",
+    "    # if done:\n",
+    "    #   obs = vec_env.reset()\n",
+    "env.close()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "9af0d167",
diff --git a/Rendu_TP1_GEREST.ipynb b/Rendu_TP1_GEREST.ipynb
index d05bc3577ae8b4d91d45b28ab935367fbaea9ff6..774b45bffb8eb61579386aa9b9034892469139d2 100644
--- a/Rendu_TP1_GEREST.ipynb
+++ b/Rendu_TP1_GEREST.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "e687aca8",
+   "id": "182bf66e",
    "metadata": {},
    "source": [
     "GEREST CORENTIN\n",
@@ -49,125 +49,46 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 1,
    "id": "c49497c0",
-   "metadata": {},
+   "metadata": {
+    "scrolled": false
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Requirement already satisfied: numpy==1.25.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.25.0)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install numpy==1.25.0 --user"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "bb4d7c39",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: gym==0.26.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (0.26.2)\n",
       "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (4.8.1)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n",
-      "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n",
       "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (0.0.8)\n",
+      "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n",
+      "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n",
       "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gym==0.26.2) (3.6.0)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install gym==0.26.2"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c4339dd2",
-   "metadata": {},
-   "source": [
-    "Install also pyglet for the rendering."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "ae74426f",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: pyglet==2.0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.0.10)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install pyglet==2.0.10"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "d6a2d90b",
-   "metadata": {},
-   "source": [
-    "If needed "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "id": "712fb75a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: pygame==2.5.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.5.2)\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "pip install pygame==2.5.2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "id": "3cdd7bcc",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
       "Requirement already satisfied: PyQt5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (5.15.10)\n",
       "Requirement already satisfied: PyQt5-Qt5>=5.15.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (5.15.2)\n",
-      "Note: you may need to restart the kernel to use updated packages.Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n",
-      "\n"
+      "Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n"
      ]
     }
    ],
    "source": [
-    "pip install PyQt5"
+    "!pip install numpy==1.25.0 --user\n",
+    "!pip install gym==0.26.2\n",
+    "\n",
+    "# Install also pyglet for the rendering.\n",
+    "!pip install pyglet==2.0.10\n",
+    "\n",
+    "# If needed\n",
+    "!pip install pygame==2.5.2\n",
+    "!pip install PyQt5"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "79581e82",
+   "id": "51bd01b6",
    "metadata": {},
    "source": [
     "### Usage\n",
@@ -177,10 +98,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 2,
    "id": "800853bd",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.4SP5SUA7CBGXUEOC35YP2ASOICYYEQZZ.gfortran-win_amd64.dll\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-gcc_10_3_0.dll\n",
+      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n",
+      "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
+      "  if not isinstance(terminated, (bool, np.bool8)):\n"
+     ]
+    }
+   ],
    "source": [
     "import gym\n",
     "\n",
@@ -208,7 +142,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "fea5e4cf",
+   "id": "5aa96f10",
    "metadata": {},
    "source": [
     "## REINFORCE\n",
@@ -236,8 +170,168 @@
     "To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/introduction).\n",
     "\n",
     "> 🛠 **To be handed in**\n",
-    "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`.\n",
+    "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f3e912fb",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c901dca8dc49499f91778a2439ab7250",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/500 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\coren\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py:58: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:264.)\n",
+      "  state_tensor = torch.FloatTensor(state).unsqueeze(0)\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "expected sequence of length 4 at dim 1 (got 0)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     56\u001b[0m     \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m         \u001b[1;31m# Compute action probabilities\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 58\u001b[1;33m         \u001b[0mstate_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     59\u001b[0m         \u001b[0maction_probs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mValueError\u001b[0m: expected sequence of length 4 at dim 1 (got 0)"
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "import numpy as np\n",
+    "from tqdm.notebook import tqdm\n",
+    "\n",
+    "\n",
+    "# Define the neural network for the policy\n",
+    "class PolicyNetwork(nn.Module):\n",
+    "    def __init__(self, input_dim, output_dim):\n",
+    "        super(PolicyNetwork, self).__init__()\n",
+    "        self.fc1 = nn.Linear(input_dim, 128)\n",
+    "        self.relu = nn.ReLU()\n",
+    "        self.dropout = nn.Dropout(p=0.6)\n",
+    "        self.fc2 = nn.Linear(128, output_dim)\n",
+    "        self.softmax = nn.Softmax(dim=1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.fc1(x)\n",
+    "        x = self.relu(x)\n",
+    "        x = self.dropout(x)\n",
+    "        x = self.fc2(x)\n",
+    "        return self.softmax(x)\n",
+    "\n",
+    "# Normalize function\n",
+    "def normalize_rewards(rewards):\n",
+    "    rewards = np.array(rewards)\n",
+    "    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)\n",
+    "    return rewards\n",
+    "\n",
+    "# Hyperparameters\n",
+    "learning_rate = 5e-3\n",
+    "gamma = 0.99\n",
+    "episodes = 500\n",
+    "\n",
+    "# Environment setup\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "\n",
+    "input_dim = env.observation_space.shape[0]\n",
+    "output_dim = env.action_space.n\n",
+    "\n",
+    "# Policy network\n",
+    "policy = PolicyNetwork(input_dim, output_dim)\n",
+    "optimizer = optim.Adam(policy.parameters(), lr=learning_rate)\n",
+    "\n",
+    "# Training loop\n",
+    "episode_rewards = []\n",
+    "for episode in tqdm(range(episodes)):\n",
+    "\n",
+    "    state = env.reset()\n",
+    "    episode_reward = 0\n",
+    "    saved_log_probs = []\n",
+    "    rewards = []\n",
+    "\n",
+    "    while True:\n",
+    "        # Compute action probabilities\n",
+    "        state_tensor = torch.FloatTensor(state).unsqueeze(0)\n",
+    "        action_probs = policy(state_tensor)\n",
+    "\n",
+    "        # Sample action\n",
+    "        action = torch.multinomial(action_probs, 1).item()\n",
+    "        log_prob = torch.log(action_probs.squeeze(0)[action])\n",
+    "        saved_log_probs.append(log_prob)\n",
+    "\n",
+    "        # Step env with action\n",
+    "        next_state, reward, done, _ = env.step(action)\n",
+    "        rewards.append(reward)\n",
+    "        episode_reward += reward\n",
+    "\n",
+    "\n",
+    "        if done:\n",
+    "            # Compute return\n",
+    "            returns = []\n",
+    "            R = 0\n",
+    "            for r in rewards[::-1]:\n",
+    "                R = r + gamma * R\n",
+    "                returns.insert(0, R)\n",
     "\n",
+    "            # Normalize returns\n",
+    "            returns = normalize_rewards(returns)\n",
+    "\n",
+    "            # Update policy\n",
+    "            policy_loss = torch.stack(saved_log_probs) * torch.tensor(returns)\n",
+    "            policy_loss = -policy_loss.sum()\n",
+    "            optimizer.zero_grad()\n",
+    "            policy_loss.backward()\n",
+    "            optimizer.step()\n",
+    "\n",
+    "            episode_rewards.append(episode_reward)\n",
+    "            break\n",
+    "\n",
+    "        state = next_state"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "06261130",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plotting\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "plt.plot(episode_rewards)\n",
+    "plt.xlabel('Episode')\n",
+    "plt.ylabel('Total Reward')\n",
+    "plt.title('REINFORCE on CartPole')\n",
+    "plt.savefig('reinforce_cartpole_rewards2.png')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fea5e4cf",
+   "metadata": {},
+   "source": [
     "## Familiarization with a complete RL pipeline: Application to training a robotic arm\n",
     "\n",
     "In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.\n",
@@ -251,7 +345,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 4,
    "id": "7b5d4e63",
    "metadata": {
     "scrolled": true
@@ -261,50 +355,45 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Collecting stable-baselines3\n",
-      "  Using cached stable_baselines3-2.2.1-py3-none-any.whl (181 kB)\n",
+      "Requirement already satisfied: stable-baselines3 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (2.2.1)\n",
+      "Requirement already satisfied: gymnasium<0.30,>=0.28.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from stable-baselines3) (0.29.1)\n",
       "Requirement already satisfied: matplotlib in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (3.4.3)\n",
-      "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.26.4)\n",
+      "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.25.0)\n",
       "Requirement already satisfied: pandas in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.3.4)\n",
-      "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n",
       "Requirement already satisfied: cloudpickle in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.0.0)\n",
-      "Collecting gymnasium<0.30,>=0.28.1\n",
-      "  Using cached gymnasium-0.29.1-py3-none-any.whl (953 kB)\n",
-      "Collecting farama-notifications>=0.0.1\n",
-      "  Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
-      "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n",
+      "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n",
       "Requirement already satisfied: typing-extensions>=4.3.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.4.0)\n",
+      "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n",
+      "Requirement already satisfied: farama-notifications>=0.0.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)\n",
       "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gymnasium<0.30,>=0.28.1->stable-baselines3) (3.6.0)\n",
-      "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n",
+      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n",
       "Requirement already satisfied: sympy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (1.9)\n",
-      "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n",
       "Requirement already satisfied: networkx in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.6.3)\n",
-      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n",
+      "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n",
+      "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n",
       "Requirement already satisfied: MarkupSafe>=0.23 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from jinja2->torch>=1.13->stable-baselines3) (1.1.1)\n",
-      "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n",
       "Requirement already satisfied: pillow>=6.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (8.4.0)\n",
-      "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n",
       "Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (1.3.1)\n",
+      "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n",
+      "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n",
       "Requirement already satisfied: cycler>=0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (0.10.0)\n",
       "Requirement already satisfied: six in c:\\users\\coren\\anaconda3\\lib\\site-packages (from cycler>=0.10->matplotlib->stable-baselines3) (1.16.0)\n",
       "Requirement already satisfied: pytz>=2017.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from pandas->stable-baselines3) (2021.3)\n",
       "Requirement already satisfied: mpmath>=0.19 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from sympy->torch>=1.13->stable-baselines3) (1.2.1)\n",
-      "Installing collected packages: farama-notifications, gymnasium, stable-baselines3\n",
-      "Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1 stable-baselines3-2.2.1\n",
       "Requirement already satisfied: moviepy in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.0.3)\n",
-      "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n",
+      "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n",
       "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.26.0)\n",
-      "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n",
-      "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.26.4)\n",
       "Requirement already satisfied: proglog<=1.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.1.10)\n",
-      "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n",
       "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.62.3)\n",
+      "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n",
+      "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.25.0)\n",
+      "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n",
       "Requirement already satisfied: pillow in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (8.4.0)\n",
       "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio-ffmpeg>=0.2.0->moviepy) (58.0.4)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n",
       "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.2)\n",
-      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n",
       "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.0.4)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n",
       "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.4)\n"
      ]
     }
@@ -316,30 +405,212 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 15,
    "id": "14c1bc63",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ac2e8543ce234f088dce3fec9f63d59a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/10000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAO7UlEQVR4nO3cYazddX3H8fdnvVQE51rXO1Pb6i1J42yME3bDihpDxG2FGUmMD2jCcETTmIFDt8SAPjB7posxSmZgjVbHdKBD3BrChgtqyB4I3ApioXRWUHulrteYgdEsUP3uwflXT6733nN6e8qh5/d+JTc9/9/vf879/dry7jn/cy6pKiRJk+23xr0ASdLpZ+wlqQHGXpIaYOwlqQHGXpIaMDXuBSxlw4YNNTMzM+5lSNIZY//+/T+uqunl5p+XsZ+ZmWFubm7cy5CkM0aS768072UcSWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWrAwNgn2ZvkWJIDy8wnyY1JDid5OMkFi+bXJHkwyZ2jWrQk6eQM88z+s8DOFeYvBbZ1X7uBmxbNXwccXM3iJEmjMTD2VXUv8JMVTrkcuKV6vgGsS7IRIMlm4M+AT41isZKk1RnFNftNwJG+4/luDODjwPuBXw56kCS7k8wlmVtYWBjBsiRJJ4wi9llirJK8BThWVfuHeZCq2lNVs1U1Oz09PYJlSZJOGEXs54EtfcebgSeB1wNvTfI94DbgTUk+N4LvJ0k6SaOI/T7gqu5TOTuAp6rqaFXdUFWbq2oGuAL4alVdOYLvJ0k6SVODTkhyK3AxsCHJPPAh4CyAqroZuAu4DDgM/By4+nQtVpK0OgNjX1W7BswXcM2Ac74OfP1kFiZJGh1/glaSGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBA2OfZG+SY0kOLDOfJDcmOZzk4SQXdONbknwtycEkjyS5btSLlyQNZ5hn9p8Fdq4wfymwrfvaDdzUjR8H/qaqXgXsAK5Jsn31S5UkrdbA2FfVvcBPVjjlcuCW6vkGsC7Jxqo6WlXf7B7jp8BBYNMoFi1JOjmjuGa/CTjSdzzPoqgnmQHOB+4bwfeTJJ2kUcQ+S4zVryaTFwFfAt5bVU8v+yDJ7iRzSeYWFhZGsCxJ0gmjiP08sKXveDPwJECSs+iF/vNVdcdKD1JVe6pqtqpmp6enR7AsSdIJo4j9PuCq7lM5O4CnqupokgCfBg5W1cdG8H0kSas0NeiEJLcCFwMbkswDHwLOAqiqm4G7gMuAw8DPgau7u74e+HPg20ke6sY+UFV3jXD9kqQhDIx9Ve0aMF/ANUuM/xdLX8+XJD3H/AlaSWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWrAwNgn2ZvkWJIDy8wnyY1JDid5OMkFfXM7kxzq5q4f5cIlScMb5pn9Z4GdK8xfCmzrvnYDNwEkWQN8spvfDuxKsv1UFitJWp2pQSdU1b1JZlY45XLglqoq4BtJ1iXZCMwAh6vqcYAkt3XnPnrKq17Gdbc9yDPHf3m6Hl6STqsXn30WH3n7a07LYw+M/RA2AUf6jue7saXG/2i5B0mym94rA17+8pevaiFP/Phn/N+zv1jVfSVp3Nads/a0PfYoYp8lxmqF8SVV1R5gD8Ds7Oyy561k37VvWM3dJGnijSL288CWvuPNwJPA2mXGJUnPsVF89HIfcFX3qZwdwFNVdRR4ANiWZGuStcAV3bmSpOfYwGf2SW4FLgY2JJkHPgScBVBVNwN3AZcBh4GfA1d3c8eTXAvcDawB9lbVI6dhD5KkAYb5NM6uAfMFXLPM3F30/jGQJI2RP0ErSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUgKFin2RnkkNJDie5fon59Um+nOThJPcneXXf3PuSPJLkQJJbk5w9yg1IkgYbGPska4BPApcC24FdSbYvOu0DwENV9RrgKuAT3X03AX8FzFbVq4E1wBWjW74kaRjDPLO/EDhcVY9X1TPAbcDli87ZDtwDUFWPATNJXtrNTQEvTDIFnAM8OZKVS5KGNkzsNwFH+o7nu7F+3wLeBpDkQuAVwOaq+iHwUeAHwFHgqar6yqkuWpJ0coaJfZYYq0XHHwbWJ3kIeA/wIHA8yXp6rwK2Ai8Dzk1y5ZLfJNmdZC7J3MLCwrDrlyQNYZjYzwNb+o43s+hSTFU9XVVXV9Vr6V2znwaeAN4MPFFVC1X1LHAH8LqlvklV7amq2aqanZ6ePvmdSJKWNUzsHwC2JdmaZC29N1j39Z+QZF03B/Au4N6qepre5ZsdSc5JEuAS4ODoli9JGsbUoBOq6niSa4G76X2aZm9VPZLk3d38zcCrgFuS/AJ4FHhnN3dfktuBbwLH6V3e2XNadiJJWlaqFl9+H7/Z2dmam5sb9zIk6YyRZH9VzS4370/QSlIDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDhop9kp1JDiU5nOT6JebXJ/lykoeT3J/k1X1z65LcnuSxJAeTXDTKDUiSBhsY+yRrgE8ClwLbgV1Jti867QPAQ1X1GuAq4BN9c58A/qOqfh/4A+DgKBYuSRreMM/sLwQOV9XjVfUMcBtw+aJztgP3AFTVY8BMkpcmeTHwRuDT3dwzVfW/o1q8JGk4w8R+E3Ck73i+G+v3LeBtAEkuBF4BbAbOAxaAzyR5MMmnkpy71DdJsjvJXJK5hYWFk9yGJGklw8Q+S4zVouMPA+uTPAS8B3gQOA5MARcAN1XV+cDPgN+45g9QVXuqaraqZqenp4dcviRpGFNDnDMPbOk73gw82X9CVT0NXA2QJMAT3dc5wHxV3dedejvLxF6SdPoM88z+AWBbkq1J1gJXAPv6T+g+cbO2O3wXcG9VPV1VPwKOJHllN3cJ8OiI1i5JGtLAZ/ZVdTzJtcDdwBpgb1U9kuTd3fzNwKuAW5L8gl7M39n3EO8BPt/9Y/A43SsASdJzJ1WLL7+P3+zsbM3NzY17GZJ0xkiyv6pml5v3J2glqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIakKoa9xp+Q5IF4PurvPsG4McjXM6ZwD1Pvtb2C+75ZL2iqqaXm3xexv5UJJmrqtlxr+O55J4nX2v7Bfc8al7GkaQGGHtJasAkxn7PuBcwBu558rW2X3DPIzVx1+wlSb9pEp/ZS5IWMfaS1ICJiX2SnUkOJTmc5Ppxr+dUJNmS5GtJDiZ5JMl13fhLkvxnku90v67vu88N3d4PJfnTvvE/TPLtbu7GJBnHnoaRZE2SB5Pc2R1P+n7XJbk9yWPdn/VFDez5fd3f6QNJbk1y9qTtOcneJMeSHOgbG9kek7wgyRe68fuSzAy1sKo647+ANcB3gfOAtcC3gO3jXtcp7GcjcEF3+7eB/wa2A38HXN+NXw98pLu9vdvzC4Ct3e/Fmm7ufuAiIMC/A5eOe38r7PuvgX8G7uyOJ32//wi8q7u9Flg3yXsGNgFPAC/sjr8I/MWk7Rl4I3ABcKBvbGR7BP4SuLm7fQXwhaHWNe7fmBH95l4E3N13fANww7jXNcL9/Rvwx8AhYGM3thE4tNR+gbu735ONwGN947uAfxj3fpbZ42bgHuBN/Dr2k7zfF3fhy6LxSd7zJuAI8BJgCrgT+JNJ3DMwsyj2I9vjiXO621P0fuI2g9Y0KZdxTvwlOmG+GzvjdS/RzgfuA15aVUcBul9/rzttuf1v6m4vHn8++jjwfuCXfWOTvN/zgAXgM92lq08lOZcJ3nNV/RD4KPAD4CjwVFV9hQnec59R7vFX96mq48BTwO8OWsCkxH6p63Vn/GdKk7wI+BLw3qp6eqVTlxirFcafV5K8BThWVfuHvcsSY2fMfjtT9F7q31RV5wM/o/fyfjln/J6769SX07tc8TLg3CRXrnSXJcbOqD0PYTV7XNX+JyX288CWvuPNwJNjWstIJDmLXug/X1V3dMP/k2RjN78RONaNL7f/+e724vHnm9cDb03yPeA24E1JPsfk7hd6a52vqvu649vpxX+S9/xm4ImqWqiqZ4E7gNcx2Xs+YZR7/NV9kkwBvwP8ZNACJiX2DwDbkmxNspbemxb7xrymVevedf80cLCqPtY3tQ94R3f7HfSu5Z8Yv6J7l34rsA24v3u5+NMkO7rHvKrvPs8bVXVDVW2uqhl6f3ZfraormdD9AlTVj4AjSV7ZDV0CPMoE75ne5ZsdSc7p1noJcJDJ3vMJo9xj/2O9nd5/L4Nf2Yz7jYwRviFyGb1PrXwX+OC413OKe3kDvZdlDwMPdV+X0bsudw/wne7Xl/Td54Pd3g/R98kEYBY40M39PUO8kTPmvV/Mr9+gnej9Aq8F5ro/538F1jew578FHuvW+0/0PoUyUXsGbqX3nsSz9J6Fv3OUewTOBv4FOEzvEzvnDbMu/3cJktSASbmMI0lagbGXpAYYe0lqgLGXpAYYe0lqgLGXpAYYe0lqwP8DUCQh/QV429IAAAAASUVORK5CYII=\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "import gymnasium as gym\n",
     "from stable_baselines3 import A2C\n",
+    "import matplotlib.pyplot as plt\n",
     "\n",
-    "\n",
-    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")#, render_mode=\"human\")\n",
     "\n",
     "model = A2C(\"MlpPolicy\", env, verbose=0)\n",
     "model.learn(total_timesteps=10_000)\n",
     "\n",
     "vec_env = model.get_env()\n",
     "obs = vec_env.reset()\n",
-    "for i in range(1000):\n",
+    "rewards = []\n",
+    "for i in tqdm(range(10000)):\n",
     "    action, _state = model.predict(obs, deterministic=True)\n",
     "    obs, reward, done, info = vec_env.step(action)\n",
+    "    rewards.append(reward)\n",
     "    #vec_env.render(\"human\")\n",
     "    # VecEnv resets automatically\n",
     "    # if done:\n",
     "    #   obs = vec_env.reset()\n",
-    "env.close()"
+    "env.close()\n",
+    "\n",
+    "plt.plot(rewards)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "42f0141f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using cpu device\n"
+     ]
+    },
+    {
+     "ename": "ImportError",
+     "evalue": "Missing shimmy installation. You provided an OpenAI Gym environment. Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. In order to use OpenAI Gym environments with SB3, you need to install shimmy (`pip install 'shimmy>=0.2.1'`).",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\vec_env\\patch_gym.py\u001b[0m in \u001b[0;36m_patch_env\u001b[1;34m(env)\u001b[0m\n\u001b[0;32m     39\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 40\u001b[1;33m         \u001b[1;32mimport\u001b[0m \u001b[0mshimmy\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     41\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'shimmy'",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "\u001b[1;31mImportError\u001b[0m                               Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/716802256.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[1;31m# Lists to store rewards during training\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\a2c\\a2c.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, rms_prop_eps, use_rms_prop, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, normalize_advantage, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)\u001b[0m\n\u001b[0;32m     90\u001b[0m         \u001b[0m_init_setup_model\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     91\u001b[0m     ):\n\u001b[1;32m---> 92\u001b[1;33m         super().__init__(\n\u001b[0m\u001b[0;32m     93\u001b[0m             \u001b[0mpolicy\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     94\u001b[0m             \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\on_policy_algorithm.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, stats_window_size, tensorboard_log, monitor_wrapper, policy_kwargs, verbose, seed, device, _init_setup_model, supported_action_spaces)\u001b[0m\n\u001b[0;32m     83\u001b[0m         \u001b[0msupported_action_spaces\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTuple\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mType\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mspaces\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSpace\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m...\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     84\u001b[0m     ):\n\u001b[1;32m---> 85\u001b[1;33m         super().__init__(\n\u001b[0m\u001b[0;32m     86\u001b[0m             \u001b[0mpolicy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     87\u001b[0m             \u001b[0menv\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, policy_kwargs, stats_window_size, tensorboard_log, verbose, device, support_multi_env, monitor_wrapper, seed, use_sde, sde_sample_freq, supported_action_spaces)\u001b[0m\n\u001b[0;32m    167\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0menv\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    168\u001b[0m             \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmaybe_make_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 169\u001b[1;33m             \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_wrap_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmonitor_wrapper\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    170\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    171\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m_wrap_env\u001b[1;34m(env, verbose, monitor_wrapper)\u001b[0m\n\u001b[0;32m    214\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mVecEnv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    215\u001b[0m             \u001b[1;31m# Patch to support gym 0.21/0.26 and gymnasium\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 216\u001b[1;33m             \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_patch_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    217\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_wrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mMonitor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mmonitor_wrapper\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    218\u001b[0m                 \u001b[1;32mif\u001b[0m \u001b[0mverbose\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\vec_env\\patch_gym.py\u001b[0m in \u001b[0;36m_patch_env\u001b[1;34m(env)\u001b[0m\n\u001b[0;32m     40\u001b[0m         \u001b[1;32mimport\u001b[0m \u001b[0mshimmy\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     41\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 42\u001b[1;33m         raise ImportError(\n\u001b[0m\u001b[0;32m     43\u001b[0m             \u001b[1;34m\"Missing shimmy installation. You provided an OpenAI Gym environment. \"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     44\u001b[0m             \u001b[1;34m\"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. \"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mImportError\u001b[0m: Missing shimmy installation. You provided an OpenAI Gym environment. Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. In order to use OpenAI Gym environments with SB3, you need to install shimmy (`pip install 'shimmy>=0.2.1'`)."
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "from stable_baselines3 import A2C\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "env = gym.make(\"CartPole-v1\")\n",
+    "\n",
+    "model = A2C(\"MlpPolicy\", env, verbose=1)\n",
+    "\n",
+    "# Lists to store rewards during training\n",
+    "episode_rewards = []\n",
+    "\n",
+    "total_timesteps = 10000\n",
+    "for timestep in range(total_timesteps):\n",
+    "    # Train for one step\n",
+    "    obs = env.reset()\n",
+    "    episode_reward = 0\n",
+    "    done = False\n",
+    "    while not done:\n",
+    "        action, _ = model.predict(obs, deterministic=True)\n",
+    "        obs, reward, done, _ = env.step(action)\n",
+    "        episode_reward += reward\n",
+    "    \n",
+    "    # Log episode reward\n",
+    "    episode_rewards.append(episode_reward)\n",
+    "\n",
+    "    # Display training progress\n",
+    "    if timestep % 1000 == 0:\n",
+    "        print(f\"Timestep: {timestep}/{total_timesteps}\")\n",
+    "\n",
+    "# Plot rewards over training\n",
+    "plt.plot(episode_rewards)\n",
+    "plt.xlabel('Episode')\n",
+    "plt.ylabel('Episode Reward')\n",
+    "plt.title('Evolution of Episode Reward during Training')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "6d4703d4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using cpu device\n"
+     ]
+    },
+    {
+     "ename": "AttributeError",
+     "evalue": "'A2C' object has no attribute '_logger'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/1163910362.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     19\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mtimestep\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m     \u001b[1;31m# Train for one step\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m     \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     22\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     23\u001b[0m     \u001b[1;31m# Log mean episode reward every 1000 timesteps\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\a2c\\a2c.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    140\u001b[0m         \u001b[1;31m# Update optimizer learning rate\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_update_learning_rate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    142\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    143\u001b[0m         \u001b[1;31m# This will only loop once (get all data in one go)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m_update_learning_rate\u001b[1;34m(self, optimizers)\u001b[0m\n\u001b[0;32m    293\u001b[0m         \"\"\"\n\u001b[0;32m    294\u001b[0m         \u001b[1;31m# Log the current learning rate\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 295\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"train/learning_rate\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlr_schedule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_current_progress_remaining\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    296\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    297\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36mlogger\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    269\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mLogger\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    270\u001b[0m         \u001b[1;34m\"\"\"Getter for the logger object.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_logger\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    272\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    273\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_setup_lr_schedule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAttributeError\u001b[0m: 'A2C' object has no attribute '_logger'"
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "from stable_baselines3 import A2C\n",
+    "from stable_baselines3.common.env_util import make_vec_env\n",
+    "from stable_baselines3.common.evaluation import evaluate_policy\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# Create and wrap the CartPole environment\n",
+    "env = make_vec_env('CartPole-v1', n_envs=4)\n",
+    "\n",
+    "# Define the A2C model\n",
+    "model = A2C('MlpPolicy', env, verbose=1)\n",
+    "\n",
+    "# Lists to store metrics\n",
+    "mean_rewards = []\n",
+    "entropies = []\n",
+    "\n",
+    "# Train the model\n",
+    "total_timesteps = 10000\n",
+    "for timestep in range(total_timesteps):\n",
+    "    # Train for one step\n",
+    "    model.train()\n",
+    "    \n",
+    "    # Log mean episode reward every 1000 timesteps\n",
+    "    if timestep % 1000 == 0:\n",
+    "        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)\n",
+    "        mean_rewards.append(mean_reward)\n",
+    "    \n",
+    "    # Log entropy every 100 timesteps\n",
+    "    if timestep % 100 == 0:\n",
+    "        action_probs = model.policy.action_net(torch.FloatTensor(env.reset()))\n",
+    "        entropy = -(action_probs * action_probs.log()).sum(dim=-1).mean().item()\n",
+    "        entropies.append(entropy)\n",
+    "\n",
+    "# Evaluate the model\n",
+    "mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)\n",
+    "print(f\"Mean reward: {mean_reward}\")\n",
+    "\n",
+    "# Plotting\n",
+    "def plot_metrics(metrics, ylabel, title):\n",
+    "    plt.plot(metrics)\n",
+    "    plt.xlabel('Timestep')\n",
+    "    plt.ylabel(ylabel)\n",
+    "    plt.title(title)\n",
+    "    plt.show()\n",
+    "\n",
+    "plot_metrics(mean_rewards, 'Mean Reward', 'Mean Reward Over Time')\n",
+    "plot_metrics(entropies, 'Entropy', 'Entropy of Action Distribution Over Time')"
    ]
   },
   {
@@ -364,10 +635,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "id": "cd890835",
-   "metadata": {},
-   "outputs": [],
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting huggingface-sb3==2.3.1\n",
+      "  Downloading huggingface_sb3-2.3.1-py3-none-any.whl (9.5 kB)\n",
+      "Collecting huggingface-hub~=0.8\n",
+      "  Downloading huggingface_hub-0.20.3-py3-none-any.whl (330 kB)\n",
+      "Requirement already satisfied: pyyaml~=6.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (6.0)\n",
+      "Collecting wasabi\n",
+      "  Downloading wasabi-1.1.2-py3-none-any.whl (27 kB)\n",
+      "Requirement already satisfied: cloudpickle>=1.6 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (2.0.0)\n",
+      "Requirement already satisfied: numpy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (1.25.0)\n",
+      "Requirement already satisfied: packaging>=20.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (21.0)\n",
+      "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.4.0)\n",
+      "Requirement already satisfied: requests in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.26.0)\n",
+      "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.3.1)\n",
+      "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.62.3)\n",
+      "Collecting fsspec>=2023.5.0\n",
+      "  Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n",
+      "Requirement already satisfied: pyparsing>=2.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from packaging>=20.9->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.0.4)\n",
+      "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm>=4.42.1->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (0.4.4)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2021.10.8)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.0.4)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.2)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (1.26.7)\n",
+      "Collecting colorama\n",
+      "  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
+      "Installing collected packages: colorama, fsspec, wasabi, huggingface-hub, huggingface-sb3\n",
+      "  Attempting uninstall: colorama\n",
+      "    Found existing installation: colorama 0.4.4\n",
+      "    Uninstalling colorama-0.4.4:\n",
+      "      Successfully uninstalled colorama-0.4.4\n",
+      "  Attempting uninstall: fsspec\n",
+      "    Found existing installation: fsspec 2021.10.1\n",
+      "    Uninstalling fsspec-2021.10.1:\n",
+      "      Successfully uninstalled fsspec-2021.10.1\n",
+      "Successfully installed colorama-0.4.6 fsspec-2024.2.0 huggingface-hub-0.20.3 huggingface-sb3-2.3.1 wasabi-1.1.2\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
    "source": [
     "pip install huggingface-sb3==2.3.1"
    ]
@@ -398,14 +713,238 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "id": "6645c23a",
-   "metadata": {},
-   "outputs": [],
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting wandb\n",
+      "  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)\n",
+      "Collecting tensorboard\n",
+      "  Downloading tensorboard-2.16.2-py3-none-any.whl (5.5 MB)\n",
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+      "conda-repo-cli 1.0.4 requires pathlib, which is not installed.\n",
+      "anaconda-project 0.10.1 requires ruamel-yaml, which is not installed.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: requests<3,>=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n",
+      "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (58.0.4)\n",
+      "Requirement already satisfied: appdirs>=1.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (1.4.4)\n",
+      "Requirement already satisfied: PyYAML in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n",
+      "Requirement already satisfied: typing-extensions in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (4.4.0)\n",
+      "Collecting GitPython!=3.1.29,>=1.0.0\n",
+      "  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)\n",
+      "Collecting sentry-sdk>=1.0.0\n",
+      "  Downloading sentry_sdk-1.40.5-py2.py3-none-any.whl (258 kB)\n",
+      "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n",
+      "Collecting setproctitle\n",
+      "  Downloading setproctitle-1.3.3-cp39-cp39-win_amd64.whl (11 kB)\n",
+      "Collecting protobuf!=4.21.0,<5,>=3.19.0\n",
+      "  Downloading protobuf-4.25.3-cp39-cp39-win_amd64.whl (413 kB)\n",
+      "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n",
+      "Collecting docker-pycreds>=0.4.0\n",
+      "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
+      "Requirement already satisfied: six>1.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.16.0)\n",
+      "Collecting tensorboard-data-server<0.8.0,>=0.7.0\n",
+      "  Downloading tensorboard_data_server-0.7.2-py3-none-any.whl (2.4 kB)\n",
+      "Requirement already satisfied: werkzeug>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (2.0.2)\n",
+      "Collecting markdown>=2.6.8\n",
+      "  Downloading Markdown-3.5.2-py3-none-any.whl (103 kB)\n",
+      "Collecting absl-py>=0.4\n",
+      "  Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)\n",
+      "Requirement already satisfied: numpy>=1.12.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.25.0)\n",
+      "Collecting grpcio>=1.48.2\n",
+      "  Downloading grpcio-1.62.0-cp39-cp39-win_amd64.whl (3.8 MB)\n",
+      "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n",
+      "Collecting gitdb<5,>=4.0.1\n",
+      "  Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
+      "Collecting smmap<6,>=3.0.1\n",
+      "  Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n",
+      "Requirement already satisfied: importlib-metadata>=4.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from markdown>=2.6.8->tensorboard) (4.8.1)\n",
+      "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (3.6.0)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n",
+      "Collecting urllib3<1.27,>=1.21.1\n",
+      "  Downloading urllib3-1.26.18-py2.py3-none-any.whl (143 kB)\n",
+      "Installing collected packages: smmap, urllib3, gitdb, tensorboard-data-server, setproctitle, sentry-sdk, protobuf, markdown, grpcio, GitPython, docker-pycreds, absl-py, wandb, tensorboard\n",
+      "  Attempting uninstall: urllib3\n",
+      "    Found existing installation: urllib3 1.26.7\n",
+      "    Uninstalling urllib3-1.26.7:\n",
+      "      Successfully uninstalled urllib3-1.26.7\n",
+      "Successfully installed GitPython-3.1.42 absl-py-2.1.0 docker-pycreds-0.4.0 gitdb-4.0.11 grpcio-1.62.0 markdown-3.5.2 protobuf-4.25.3 sentry-sdk-1.40.5 setproctitle-1.3.3 smmap-5.0.1 tensorboard-2.16.2 tensorboard-data-server-0.7.2 urllib3-1.26.18 wandb-0.16.3\n"
+     ]
+    }
+   ],
    "source": [
     "pip install wandb tensorboard"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "f94466df",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:\n",
+      "Traceback (most recent call last):\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 1172, in init\n",
+      "    wi.setup(kwargs)\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 306, in setup\n",
+      "    wandb_login._login(\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 317, in _login\n",
+      "    wlogin.prompt_api_key()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 240, in prompt_api_key\n",
+      "    key, status = self._prompt_api_key()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 220, in _prompt_api_key\n",
+      "    key = apikey.prompt_api_key(\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\", line 151, in prompt_api_key\n",
+      "    key = input_callback(api_ask).strip()\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 168, in prompt\n",
+      "    value = prompt_func(prompt)\n",
+      "  File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 150, in prompt_func\n",
+      "    raise Abort() from None\n",
+      "click.exceptions.Abort\n"
+     ]
+    },
+    {
+     "ename": "Error",
+     "evalue": "An unexpected error occurred",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAbort\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m   1171\u001b[0m         \u001b[0mwi\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_WandbInit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1172\u001b[1;33m         \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1173\u001b[0m         \u001b[1;32massert\u001b[0m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36msetup\u001b[1;34m(self, kwargs)\u001b[0m\n\u001b[0;32m    305\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_offline\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_noop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m             wandb_login._login(\n\u001b[0m\u001b[0;32m    307\u001b[0m                 \u001b[0manonymous\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"anonymous\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_login\u001b[1;34m(anonymous, key, relogin, host, force, timeout, _backend, _silent, _disable_warning, _entity)\u001b[0m\n\u001b[0;32m    316\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 317\u001b[1;33m         \u001b[0mwlogin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    318\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    239\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 240\u001b[1;33m         \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_prompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    241\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mApiKeyStatus\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNOTTY\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_prompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    219\u001b[0m             \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 220\u001b[1;33m                 key = apikey.prompt_api_key(\n\u001b[0m\u001b[0;32m    221\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(settings, api, input_callback, browser_callback, no_offline, no_create, local)\u001b[0m\n\u001b[0;32m    150\u001b[0m             )\n\u001b[1;32m--> 151\u001b[1;33m             \u001b[0mkey\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_callback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mapi_ask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    152\u001b[0m         \u001b[0mwrite_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt\u001b[1;34m(text, default, hide_input, confirmation_prompt, type, value_proc, prompt_suffix, show_default, err, show_choices)\u001b[0m\n\u001b[0;32m    167\u001b[0m         \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 168\u001b[1;33m             \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mprompt_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    169\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt_func\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m    149\u001b[0m                 \u001b[0mecho\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0merr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mAbort\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    151\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAbort\u001b[0m: ",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "\u001b[1;31mError\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2820517193.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     12\u001b[0m     \u001b[1;34m\"env_name\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m }\n\u001b[1;32m---> 14\u001b[1;33m run = wandb.init(\n\u001b[0m\u001b[0;32m     15\u001b[0m     \u001b[0mproject\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"sb3\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m     \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m   1212\u001b[0m                 \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtermerror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Abnormal program exit\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1213\u001b[0m                 \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_exit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1214\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"An unexpected error occurred\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0merror_seen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1215\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mError\u001b[0m: An unexpected error occurred"
+     ]
+    }
+   ],
+   "source": [
+    "import gym\n",
+    "from stable_baselines3 import PPO\n",
+    "from stable_baselines3.common.monitor import Monitor\n",
+    "from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder\n",
+    "import wandb\n",
+    "from wandb.integration.sb3 import WandbCallback\n",
+    "\n",
+    "\n",
+    "config = {\n",
+    "    \"policy_type\": \"MlpPolicy\",\n",
+    "    \"total_timesteps\": 25000,\n",
+    "    \"env_name\": \"CartPole-v1\",\n",
+    "}\n",
+    "run = wandb.init(\n",
+    "    project=\"sb3\",\n",
+    "    config=config,\n",
+    "    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics\n",
+    "    monitor_gym=True,  # auto-upload the videos of agents playing the game\n",
+    "    save_code=True,  # optional\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def make_env():\n",
+    "    env = gym.make(config[\"env_name\"])\n",
+    "    env = Monitor(env)  # record stats such as returns\n",
+    "    return env\n",
+    "\n",
+    "\n",
+    "env = DummyVecEnv([make_env])\n",
+    "env = VecVideoRecorder(\n",
+    "    env,\n",
+    "    f\"videos/{run.id}\",\n",
+    "    record_video_trigger=lambda x: x % 2000 == 0,\n",
+    "    video_length=200,\n",
+    ")\n",
+    "model = PPO(config[\"policy_type\"], env, verbose=1, tensorboard_log=f\"runs/{run.id}\")\n",
+    "model.learn(\n",
+    "    total_timesteps=config[\"total_timesteps\"],\n",
+    "    callback=WandbCallback(\n",
+    "        gradient_save_freq=100,\n",
+    "        model_save_path=f\"models/{run.id}\",\n",
+    "        verbose=2,\n",
+    "    ),\n",
+    ")\n",
+    "run.finish()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "4dda548e",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "Error",
+     "evalue": "You must call wandb.init() before WandbCallback()",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mError\u001b[0m                                     Traceback (most recent call last)",
+      "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2131705787.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10_000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mWandbCallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     12\u001b[0m \u001b[0mvec_env\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\integration\\sb3\\sb3.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, verbose, model_save_path, model_save_freq, gradient_save_freq, log)\u001b[0m\n\u001b[0;32m     95\u001b[0m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     96\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"You must call wandb.init() before WandbCallback()\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     98\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mwb_telemetry\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     99\u001b[0m             \u001b[0mtel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msb3\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mError\u001b[0m: You must call wandb.init() before WandbCallback()"
+     ]
+    }
+   ],
+   "source": [
+    "from wandb.integration.sb3 import WandbCallback\n",
+    "\n",
+    "import gymnasium as gym\n",
+    "from stable_baselines3 import A2C\n",
+    "\n",
+    "\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "\n",
+    "model = A2C(\"MlpPolicy\", env, verbose=0)\n",
+    "model.learn(total_timesteps=10_000, callback=WandbCallback())\n",
+    "\n",
+    "vec_env = model.get_env()\n",
+    "obs = vec_env.reset()\n",
+    "for i in tqdm(range(1000)):\n",
+    "    action, _state = model.predict(obs, deterministic=True)\n",
+    "    obs, reward, done, info = vec_env.step(action)\n",
+    "    #vec_env.render(\"human\")\n",
+    "    # VecEnv resets automatically\n",
+    "    # if done:\n",
+    "    #   obs = vec_env.reset()\n",
+    "env.close()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "9af0d167",