diff --git a/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb b/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb index d05bc3577ae8b4d91d45b28ab935367fbaea9ff6..68a4f33337ff06254919d61721a04bff8b76a8b8 100644 --- a/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb +++ b/.ipynb_checkpoints/Rendu_TP1_GEREST-checkpoint.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e687aca8", + "id": "182bf66e", "metadata": {}, "source": [ "GEREST CORENTIN\n", @@ -49,125 +49,46 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "id": "c49497c0", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: numpy==1.25.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.25.0)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install numpy==1.25.0 --user" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "bb4d7c39", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: gym==0.26.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (0.26.2)\n", "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (4.8.1)\n", - "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n", "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n", + "Requirement already satisfied: numpy>=1.18.0 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n", "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gym==0.26.2) (3.6.0)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install gym==0.26.2" - ] - }, - { - "cell_type": "markdown", - "id": "c4339dd2", - "metadata": {}, - "source": [ - "Install also pyglet for the rendering." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ae74426f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: pyglet==2.0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.0.10)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install pyglet==2.0.10" - ] - }, - { - "cell_type": "markdown", - "id": "d6a2d90b", - "metadata": {}, - "source": [ - "If needed " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "712fb75a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: pygame==2.5.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.5.2)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install pygame==2.5.2" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3cdd7bcc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: PyQt5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (5.15.10)\n", "Requirement already satisfied: PyQt5-Qt5>=5.15.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (5.15.2)\n", - "Note: you may need to restart the kernel to use updated packages.Requirement already satisfied: PyQt5-sip<13,>=12.13 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n", - "\n" + "Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n" ] } ], "source": [ - "pip install PyQt5" + "!pip install numpy==1.25.0 --user\n", + "!pip install gym==0.26.2\n", + "\n", + "# Install also pyglet for the rendering.\n", + "!pip install pyglet==2.0.10\n", + "\n", + "# If needed\n", + "!pip install pygame==2.5.2\n", + "!pip install PyQt5" ] }, { "cell_type": "markdown", - "id": "79581e82", + "id": "51bd01b6", "metadata": {}, "source": [ "### Usage\n", @@ -177,10 +98,23 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "id": "800853bd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.4SP5SUA7CBGXUEOC35YP2ASOICYYEQZZ.gfortran-win_amd64.dll\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-gcc_10_3_0.dll\n", + " warnings.warn(\"loaded more than 1 DLL from .libs:\"\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n" + ] + } + ], "source": [ "import gym\n", "\n", @@ -208,7 +142,7 @@ }, { "cell_type": "markdown", - "id": "fea5e4cf", + "id": "5aa96f10", "metadata": {}, "source": [ "## REINFORCE\n", @@ -236,8 +170,168 @@ "To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/introduction).\n", "\n", "> 🛠**To be handed in**\n", - "> Use PyTorch to implement REINFORCE and solve the CartPole environement. 
Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`.\n", + "> Use PyTorch to implement REINFORCE and solve the CartPole environment. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward across episodes in the `README.md`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f3e912fb", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c901dca8dc49499f91778a2439ab7250", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/500 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\coren\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py:58: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. 
(Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:264.)\n", + " state_tensor = torch.FloatTensor(state).unsqueeze(0)\n" + ] + }, + { + "ename": "ValueError", + "evalue": "expected sequence of length 4 at dim 1 (got 0)", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 56\u001b[0m \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 57\u001b[0m \u001b[1;31m# Compute action probabilities\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 58\u001b[1;33m \u001b[0mstate_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 59\u001b[0m \u001b[0maction_probs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mValueError\u001b[0m: expected sequence of length 4 at dim 1 (got 0)" + ] + } + ], + "source": [ + "import gym\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "\n", + "\n", + "# Define the neural network for the policy\n", + "class PolicyNetwork(nn.Module):\n", 
+ " def __init__(self, input_dim, output_dim):\n", + " super(PolicyNetwork, self).__init__()\n", + " self.fc1 = nn.Linear(input_dim, 128)\n", + " self.relu = nn.ReLU()\n", + " self.dropout = nn.Dropout(p=0.6)\n", + " self.fc2 = nn.Linear(128, output_dim)\n", + " self.softmax = nn.Softmax(dim=1)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.relu(x)\n", + " x = self.dropout(x)\n", + " x = self.fc2(x)\n", + " return self.softmax(x)\n", + "\n", + "# Normalize function\n", + "def normalize_rewards(rewards):\n", + " rewards = np.array(rewards)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)\n", + " return rewards\n", + "\n", + "# Hyperparameters\n", + "learning_rate = 5e-3\n", + "gamma = 0.99\n", + "episodes = 500\n", + "\n", + "# Environment setup\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", + "\n", + "input_dim = env.observation_space.shape[0]\n", + "output_dim = env.action_space.n\n", + "\n", + "# Policy network\n", + "policy = PolicyNetwork(input_dim, output_dim)\n", + "optimizer = optim.Adam(policy.parameters(), lr=learning_rate)\n", + "\n", + "# Training loop\n", + "episode_rewards = []\n", + "for episode in tqdm(range(episodes)):\n", + "\n", + " state = env.reset()\n", + " episode_reward = 0\n", + " saved_log_probs = []\n", + " rewards = []\n", + "\n", + " while True:\n", + " # Compute action probabilities\n", + " state_tensor = torch.FloatTensor(state).unsqueeze(0)\n", + " action_probs = policy(state_tensor)\n", "\n", + " # Sample action\n", + " action = torch.multinomial(action_probs, 1).item()\n", + " log_prob = torch.log(action_probs.squeeze(0)[action])\n", + " saved_log_probs.append(log_prob)\n", + "\n", + " # Step env with action\n", + " next_state, reward, done, _ = env.step(action)\n", + " rewards.append(reward)\n", + " episode_reward += reward\n", + "\n", + "\n", + " if done:\n", + " # Compute return\n", + " returns = []\n", + " R = 0\n", + " for r in rewards[::-1]:\n", + " R 
= r + gamma * R\n", + " returns.insert(0, R)\n", + "\n", + " # Normalize returns\n", + " returns = normalize_rewards(returns)\n", + "\n", + " # Update policy\n", + " policy_loss = torch.stack(saved_log_probs) * torch.tensor(returns)\n", + " policy_loss = -policy_loss.sum()\n", + " optimizer.zero_grad()\n", + " policy_loss.backward()\n", + " optimizer.step()\n", + "\n", + " episode_rewards.append(episode_reward)\n", + " break\n", + "\n", + " state = next_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06261130", + "metadata": {}, + "outputs": [], + "source": [ + "# Plotting\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(episode_rewards)\n", + "plt.xlabel('Episode')\n", + "plt.ylabel('Total Reward')\n", + "plt.title('REINFORCE on CartPole')\n", + "plt.savefig('reinforce_cartpole_rewards2.png')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fea5e4cf", + "metadata": {}, + "source": [ "## Familiarization with a complete RL pipeline: Application to training a robotic arm\n", "\n", "In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. 
You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.\n", @@ -251,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "id": "7b5d4e63", "metadata": { "scrolled": true @@ -261,50 +355,45 @@ "name": "stdout", "output_type": "stream", "text": [ - "Collecting stable-baselines3\n", - " Using cached stable_baselines3-2.2.1-py3-none-any.whl (181 kB)\n", + "Requirement already satisfied: stable-baselines3 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (2.2.1)\n", + "Requirement already satisfied: gymnasium<0.30,>=0.28.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from stable-baselines3) (0.29.1)\n", "Requirement already satisfied: matplotlib in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (3.4.3)\n", - "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.26.4)\n", + "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.25.0)\n", "Requirement already satisfied: pandas in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.3.4)\n", - "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n", "Requirement already satisfied: cloudpickle in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.0.0)\n", - "Collecting gymnasium<0.30,>=0.28.1\n", - " Using cached gymnasium-0.29.1-py3-none-any.whl (953 kB)\n", - "Collecting farama-notifications>=0.0.1\n", - " Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n", - "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n", + "Requirement already satisfied: torch>=1.13 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n", "Requirement already satisfied: typing-extensions>=4.3.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.4.0)\n", + "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n", + "Requirement already satisfied: farama-notifications>=0.0.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)\n", "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gymnasium<0.30,>=0.28.1->stable-baselines3) (3.6.0)\n", - "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n", + "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n", "Requirement already satisfied: sympy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (1.9)\n", - "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n", "Requirement already satisfied: networkx in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.6.3)\n", - "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n", + "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n", + "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n", "Requirement already satisfied: MarkupSafe>=0.23 
in c:\\users\\coren\\anaconda3\\lib\\site-packages (from jinja2->torch>=1.13->stable-baselines3) (1.1.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n", "Requirement already satisfied: pillow>=6.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (8.4.0)\n", - "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (1.3.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n", + "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n", "Requirement already satisfied: cycler>=0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (0.10.0)\n", "Requirement already satisfied: six in c:\\users\\coren\\anaconda3\\lib\\site-packages (from cycler>=0.10->matplotlib->stable-baselines3) (1.16.0)\n", "Requirement already satisfied: pytz>=2017.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from pandas->stable-baselines3) (2021.3)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from sympy->torch>=1.13->stable-baselines3) (1.2.1)\n", - "Installing collected packages: farama-notifications, gymnasium, stable-baselines3\n", - "Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1 stable-baselines3-2.2.1\n", "Requirement already satisfied: moviepy in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.0.3)\n", - "Requirement already satisfied: decorator<5.0,>=4.0.2 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n", + "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n", "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.26.0)\n", - "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n", - "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.26.4)\n", "Requirement already satisfied: proglog<=1.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.1.10)\n", - "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n", "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.62.3)\n", + "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n", + "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.25.0)\n", + "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n", "Requirement already satisfied: pillow in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (8.4.0)\n", "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio-ffmpeg>=0.2.0->moviepy) (58.0.4)\n", - "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.2)\n", - "Requirement already satisfied: 
urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.0.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n", "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.4)\n" ] } @@ -316,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "14c1bc63", "metadata": {}, "outputs": [], @@ -332,7 +421,7 @@ "\n", "vec_env = model.get_env()\n", "obs = vec_env.reset()\n", - "for i in range(1000):\n", + "for i in tqdm(range(1000)):\n", " action, _state = model.predict(obs, deterministic=True)\n", " obs, reward, done, info = vec_env.step(action)\n", " #vec_env.render(\"human\")\n", @@ -364,10 +453,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "cd890835", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting huggingface-sb3==2.3.1\n", + " Downloading huggingface_sb3-2.3.1-py3-none-any.whl (9.5 kB)\n", + "Collecting huggingface-hub~=0.8\n", + " Downloading huggingface_hub-0.20.3-py3-none-any.whl (330 kB)\n", + "Requirement already satisfied: pyyaml~=6.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (6.0)\n", + "Collecting wasabi\n", + " Downloading wasabi-1.1.2-py3-none-any.whl (27 kB)\n", + "Requirement already satisfied: cloudpickle>=1.6 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from 
huggingface-sb3==2.3.1) (2.0.0)\n", + "Requirement already satisfied: numpy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (1.25.0)\n", + "Requirement already satisfied: packaging>=20.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (21.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.4.0)\n", + "Requirement already satisfied: requests in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.26.0)\n", + "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.3.1)\n", + "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.62.3)\n", + "Collecting fsspec>=2023.5.0\n", + " Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from packaging>=20.9->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.0.4)\n", + "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm>=4.42.1->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (0.4.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2021.10.8)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.2)\n", + "Requirement 
already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (1.26.7)\n", + "Collecting colorama\n", + " Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n", + "Installing collected packages: colorama, fsspec, wasabi, huggingface-hub, huggingface-sb3\n", + " Attempting uninstall: colorama\n", + " Found existing installation: colorama 0.4.4\n", + " Uninstalling colorama-0.4.4:\n", + " Successfully uninstalled colorama-0.4.4\n", + " Attempting uninstall: fsspec\n", + " Found existing installation: fsspec 2021.10.1\n", + " Uninstalling fsspec-2021.10.1:\n", + " Successfully uninstalled fsspec-2021.10.1\n", + "Successfully installed colorama-0.4.6 fsspec-2024.2.0 huggingface-hub-0.20.3 huggingface-sb3-2.3.1 wasabi-1.1.2\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "pip install huggingface-sb3==2.3.1" ] @@ -398,14 +531,238 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "6645c23a", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting wandb\n", + " Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)\n", + "Collecting tensorboard\n", + " Downloading tensorboard-2.16.2-py3-none-any.whl (5.5 MB)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "conda-repo-cli 1.0.4 requires pathlib, which is not installed.\n", + "anaconda-project 0.10.1 requires ruamel-yaml, which is not installed.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: requests<3,>=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n", + "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (58.0.4)\n", + "Requirement already satisfied: appdirs>=1.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: PyYAML in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n", + "Requirement already satisfied: typing-extensions in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (4.4.0)\n", + "Collecting GitPython!=3.1.29,>=1.0.0\n", + " Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.40.5-py2.py3-none-any.whl (258 kB)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n", + "Collecting setproctitle\n", + " Downloading setproctitle-1.3.3-cp39-cp39-win_amd64.whl (11 kB)\n", + "Collecting protobuf!=4.21.0,<5,>=3.19.0\n", + " Downloading protobuf-4.25.3-cp39-cp39-win_amd64.whl (413 kB)\n", + "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n", + "Collecting docker-pycreds>=0.4.0\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Requirement already satisfied: six>1.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.16.0)\n", + "Collecting tensorboard-data-server<0.8.0,>=0.7.0\n", + " Downloading tensorboard_data_server-0.7.2-py3-none-any.whl (2.4 kB)\n", + "Requirement already satisfied: 
werkzeug>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (2.0.2)\n", + "Collecting markdown>=2.6.8\n", + " Downloading Markdown-3.5.2-py3-none-any.whl (103 kB)\n", + "Collecting absl-py>=0.4\n", + " Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)\n", + "Requirement already satisfied: numpy>=1.12.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.25.0)\n", + "Collecting grpcio>=1.48.2\n", + " Downloading grpcio-1.62.0-cp39-cp39-win_amd64.whl (3.8 MB)\n", + "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n", + "Collecting gitdb<5,>=4.0.1\n", + " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", + "Collecting smmap<6,>=3.0.1\n", + " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", + "Requirement already satisfied: importlib-metadata>=4.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from markdown>=2.6.8->tensorboard) (4.8.1)\n", + "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (3.6.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n", + "Collecting urllib3<1.27,>=1.21.1\n", + " Downloading urllib3-1.26.18-py2.py3-none-any.whl (143 kB)\n", + "Installing collected packages: smmap, urllib3, gitdb, tensorboard-data-server, setproctitle, sentry-sdk, protobuf, markdown, 
grpcio, GitPython, docker-pycreds, absl-py, wandb, tensorboard\n", + " Attempting uninstall: urllib3\n", + " Found existing installation: urllib3 1.26.7\n", + " Uninstalling urllib3-1.26.7:\n", + " Successfully uninstalled urllib3-1.26.7\n", + "Successfully installed GitPython-3.1.42 absl-py-2.1.0 docker-pycreds-0.4.0 gitdb-4.0.11 grpcio-1.62.0 markdown-3.5.2 protobuf-4.25.3 sentry-sdk-1.40.5 setproctitle-1.3.3 smmap-5.0.1 tensorboard-2.16.2 tensorboard-data-server-0.7.2 urllib3-1.26.18 wandb-0.16.3\n" + ] + } + ], "source": [ "pip install wandb tensorboard" ] }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f94466df", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:\n", + "Traceback (most recent call last):\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 1172, in init\n", + " wi.setup(kwargs)\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 306, in setup\n", + " wandb_login._login(\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 317, in _login\n", + " wlogin.prompt_api_key()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 240, in prompt_api_key\n", + " key, status = self._prompt_api_key()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 220, in _prompt_api_key\n", + " key = apikey.prompt_api_key(\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\", line 
151, in prompt_api_key\n", + " key = input_callback(api_ask).strip()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 168, in prompt\n", + " value = prompt_func(prompt)\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 150, in prompt_func\n", + " raise Abort() from None\n", + "click.exceptions.Abort\n" + ] + }, + { + "ename": "Error", + "evalue": "An unexpected error occurred", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAbort\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m 1171\u001b[0m \u001b[0mwi\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_WandbInit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1172\u001b[1;33m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1173\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36msetup\u001b[1;34m(self, kwargs)\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_offline\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m 
\u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_noop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m wandb_login._login(\n\u001b[0m\u001b[0;32m 307\u001b[0m \u001b[0manonymous\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"anonymous\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_login\u001b[1;34m(anonymous, key, relogin, host, force, timeout, _backend, _silent, _disable_warning, _entity)\u001b[0m\n\u001b[0;32m 316\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 317\u001b[1;33m \u001b[0mwlogin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 318\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 239\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 240\u001b[1;33m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_prompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 241\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m==\u001b[0m 
\u001b[0mApiKeyStatus\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNOTTY\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_prompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 219\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 220\u001b[1;33m key = apikey.prompt_api_key(\n\u001b[0m\u001b[0;32m 221\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(settings, api, input_callback, browser_callback, no_offline, no_create, local)\u001b[0m\n\u001b[0;32m 150\u001b[0m )\n\u001b[1;32m--> 151\u001b[1;33m \u001b[0mkey\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_callback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mapi_ask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 152\u001b[0m \u001b[0mwrite_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt\u001b[1;34m(text, default, hide_input, confirmation_prompt, type, value_proc, prompt_suffix, show_default, err, show_choices)\u001b[0m\n\u001b[0;32m 167\u001b[0m \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 168\u001b[1;33m \u001b[0mvalue\u001b[0m 
\u001b[1;33m=\u001b[0m \u001b[0mprompt_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 169\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt_func\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m 149\u001b[0m \u001b[0mecho\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0merr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAbort\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 151\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mAbort\u001b[0m: ", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[1;31mError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2820517193.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;34m\"env_name\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m }\n\u001b[1;32m---> 14\u001b[1;33m run = wandb.init(\n\u001b[0m\u001b[0;32m 15\u001b[0m \u001b[0mproject\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"sb3\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + 
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m 1212\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtermerror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Abnormal program exit\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1213\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_exit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1214\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"An unexpected error occurred\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0merror_seen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1215\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mError\u001b[0m: An unexpected error occurred" + ] + } + ], + "source": [ + "import gym\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.common.monitor import Monitor\n", + "from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder\n", + "import wandb\n", + "from wandb.integration.sb3 import WandbCallback\n", + "\n", + "\n", + "config = {\n", + " \"policy_type\": \"MlpPolicy\",\n", + " \"total_timesteps\": 25000,\n", + " \"env_name\": \"CartPole-v1\",\n", + "}\n", + "run = wandb.init(\n", + " project=\"sb3\",\n", + " config=config,\n", + " sync_tensorboard=True, # auto-upload sb3's tensorboard metrics\n", + " monitor_gym=True, # auto-upload the videos of agents playing the game\n", + " 
save_code=True, # optional\n", + ")\n", + "\n", + "\n", + "def make_env():\n", + " env = gym.make(config[\"env_name\"])\n", + " env = Monitor(env) # record stats such as returns\n", + " return env\n", + "\n", + "\n", + "env = DummyVecEnv([make_env])\n", + "env = VecVideoRecorder(\n", + " env,\n", + " f\"videos/{run.id}\",\n", + " record_video_trigger=lambda x: x % 2000 == 0,\n", + " video_length=200,\n", + ")\n", + "model = PPO(config[\"policy_type\"], env, verbose=1, tensorboard_log=f\"runs/{run.id}\")\n", + "model.learn(\n", + " total_timesteps=config[\"total_timesteps\"],\n", + " callback=WandbCallback(\n", + " gradient_save_freq=100,\n", + " model_save_path=f\"models/{run.id}\",\n", + " verbose=2,\n", + " ),\n", + ")\n", + "run.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4dda548e", + "metadata": {}, + "outputs": [ + { + "ename": "Error", + "evalue": "You must call wandb.init() before WandbCallback()", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2131705787.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10_000\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mWandbCallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mvec_env\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\integration\\sb3\\sb3.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, verbose, model_save_path, model_save_freq, gradient_save_freq, log)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"You must call wandb.init() before WandbCallback()\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mwb_telemetry\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msb3\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mError\u001b[0m: You must call wandb.init() before WandbCallback()" + ] + } + ], + "source": [ + "from wandb.integration.sb3 import WandbCallback\n", + "\n", + "import gymnasium as gym\n", + "from stable_baselines3 import A2C\n", + "\n", + "\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", + "\n", + "model = A2C(\"MlpPolicy\", env, verbose=0)\n", + "model.learn(total_timesteps=10_000, callback=WandbCallback())\n", + "\n", + "vec_env = model.get_env()\n", + "obs = vec_env.reset()\n", + "for i in tqdm(range(1000)):\n", + " action, _state = model.predict(obs, deterministic=True)\n", + " obs, reward, done, info = vec_env.step(action)\n", + " #vec_env.render(\"human\")\n", + " # VecEnv resets automatically\n", + " # if done:\n", + " # obs = vec_env.reset()\n", + "env.close()" + ] + }, { "cell_type": "markdown", "id": "9af0d167", diff --git a/Rendu_TP1_GEREST.ipynb b/Rendu_TP1_GEREST.ipynb index d05bc3577ae8b4d91d45b28ab935367fbaea9ff6..774b45bffb8eb61579386aa9b9034892469139d2 100644 --- a/Rendu_TP1_GEREST.ipynb +++ b/Rendu_TP1_GEREST.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e687aca8", + "id": "182bf66e", "metadata": {}, "source": [ "GEREST CORENTIN\n", @@ -49,125 +49,46 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "id": "c49497c0", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: numpy==1.25.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (1.25.0)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install numpy==1.25.0 --user" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "bb4d7c39", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: 
gym==0.26.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (0.26.2)\n", "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (4.8.1)\n", - "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n", "Requirement already satisfied: gym-notices>=0.0.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (2.0.0)\n", + "Requirement already satisfied: numpy>=1.18.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gym==0.26.2) (1.25.0)\n", "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gym==0.26.2) (3.6.0)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install gym==0.26.2" - ] - }, - { - "cell_type": "markdown", - "id": "c4339dd2", - "metadata": {}, - "source": [ - "Install also pyglet for the rendering." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ae74426f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: pyglet==2.0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.0.10)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install pyglet==2.0.10" - ] - }, - { - "cell_type": "markdown", - "id": "d6a2d90b", - "metadata": {}, - "source": [ - "If needed " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "712fb75a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: pygame==2.5.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (2.5.2)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "pip install pygame==2.5.2" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3cdd7bcc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ "Requirement already satisfied: PyQt5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (5.15.10)\n", "Requirement already satisfied: PyQt5-Qt5>=5.15.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (5.15.2)\n", - "Note: you may need to restart the kernel to use updated packages.Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n", - "\n" + "Requirement already satisfied: PyQt5-sip<13,>=12.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from PyQt5) (12.13.0)\n" ] } ], "source": [ - "pip install PyQt5" + "!pip install numpy==1.25.0 --user\n", + "!pip install gym==0.26.2\n", + "\n", + "# Install also pyglet for the rendering.\n", + "!pip install pyglet==2.0.10\n", + "\n", + "# If needed\n", + "!pip install pygame==2.5.2\n", + "!pip install PyQt5" ] }, { "cell_type": 
"markdown", - "id": "79581e82", + "id": "51bd01b6", "metadata": {}, "source": [ "### Usage\n", @@ -177,10 +98,23 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "id": "800853bd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.4SP5SUA7CBGXUEOC35YP2ASOICYYEQZZ.gfortran-win_amd64.dll\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-gcc_10_3_0.dll\n", + " warnings.warn(\"loaded more than 1 DLL from .libs:\"\n", + "C:\\Users\\coren\\anaconda3\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", + " if not isinstance(terminated, (bool, np.bool8)):\n" + ] + } + ], "source": [ "import gym\n", "\n", @@ -208,7 +142,7 @@ }, { "cell_type": "markdown", - "id": "fea5e4cf", + "id": "5aa96f10", "metadata": {}, "source": [ "## REINFORCE\n", @@ -236,8 +170,168 @@ "To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/introduction).\n", "\n", "> 🛠**To be handed in**\n", - "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`.\n", + "> Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f3e912fb", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c901dca8dc49499f91778a2439ab7250", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/500 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\coren\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py:58: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:264.)\n", + " state_tensor = torch.FloatTensor(state).unsqueeze(0)\n" + ] + }, + { + "ename": "ValueError", + "evalue": "expected sequence of length 4 at dim 1 (got 0)", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2002979212.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 56\u001b[0m \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 57\u001b[0m \u001b[1;31m# Compute action probabilities\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 58\u001b[1;33m \u001b[0mstate_tensor\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 59\u001b[0m \u001b[0maction_probs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mValueError\u001b[0m: expected sequence of length 4 at dim 1 (got 0)" + ] + } + ], + "source": [ + "import gym\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "\n", + "\n", + "# Define the neural network for the policy\n", + "class PolicyNetwork(nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(PolicyNetwork, self).__init__()\n", + " self.fc1 = nn.Linear(input_dim, 128)\n", + " self.relu = nn.ReLU()\n", + " self.dropout = nn.Dropout(p=0.6)\n", + " self.fc2 = nn.Linear(128, output_dim)\n", + " self.softmax = nn.Softmax(dim=1)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.relu(x)\n", + " x = self.dropout(x)\n", + " x = self.fc2(x)\n", + " return self.softmax(x)\n", + "\n", + "# Normalize function\n", + "def normalize_rewards(rewards):\n", + " rewards = np.array(rewards)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)\n", + " return rewards\n", + "\n", + "# Hyperparameters\n", + "learning_rate = 5e-3\n", + "gamma = 0.99\n", + "episodes = 500\n", + "\n", + "# Environment setup\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", + "\n", + "input_dim = env.observation_space.shape[0]\n", + "output_dim = env.action_space.n\n", + "\n", + "# Policy network\n", + 
"policy = PolicyNetwork(input_dim, output_dim)\n", + "optimizer = optim.Adam(policy.parameters(), lr=learning_rate)\n", + "\n", + "# Training loop\n", + "episode_rewards = []\n", + "for episode in tqdm(range(episodes)):\n", + "\n", + " state = env.reset()\n", + " episode_reward = 0\n", + " saved_log_probs = []\n", + " rewards = []\n", + "\n", + " while True:\n", + " # Compute action probabilities\n", + " state_tensor = torch.FloatTensor(state).unsqueeze(0)\n", + " action_probs = policy(state_tensor)\n", + "\n", + " # Sample action\n", + " action = torch.multinomial(action_probs, 1).item()\n", + " log_prob = torch.log(action_probs.squeeze(0)[action])\n", + " saved_log_probs.append(log_prob)\n", + "\n", + " # Step env with action\n", + " next_state, reward, done, _ = env.step(action)\n", + " rewards.append(reward)\n", + " episode_reward += reward\n", + "\n", + "\n", + " if done:\n", + " # Compute return\n", + " returns = []\n", + " R = 0\n", + " for r in rewards[::-1]:\n", + " R = r + gamma * R\n", + " returns.insert(0, R)\n", "\n", + " # Normalize returns\n", + " returns = normalize_rewards(returns)\n", + "\n", + " # Update policy\n", + " policy_loss = torch.stack(saved_log_probs) * torch.tensor(returns)\n", + " policy_loss = -policy_loss.sum()\n", + " optimizer.zero_grad()\n", + " policy_loss.backward()\n", + " optimizer.step()\n", + "\n", + " episode_rewards.append(episode_reward)\n", + " break\n", + "\n", + " state = next_state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06261130", + "metadata": {}, + "outputs": [], + "source": [ + "# Plotting\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(episode_rewards)\n", + "plt.xlabel('Episode')\n", + "plt.ylabel('Total Reward')\n", + "plt.title('REINFORCE on CartPole')\n", + "plt.savefig('reinforce_cartpole_rewards2.png')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fea5e4cf", + "metadata": {}, + "source": [ "## Familiarization with a complete RL pipeline: 
Application to training a robotic arm\n", "\n", "In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.\n", @@ -251,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "id": "7b5d4e63", "metadata": { "scrolled": true @@ -261,50 +355,45 @@ "name": "stdout", "output_type": "stream", "text": [ - "Collecting stable-baselines3\n", - " Using cached stable_baselines3-2.2.1-py3-none-any.whl (181 kB)\n", + "Requirement already satisfied: stable-baselines3 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (2.2.1)\n", + "Requirement already satisfied: gymnasium<0.30,>=0.28.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from stable-baselines3) (0.29.1)\n", "Requirement already satisfied: matplotlib in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (3.4.3)\n", - "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.26.4)\n", + "Requirement already satisfied: numpy>=1.20 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.25.0)\n", "Requirement already satisfied: pandas in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (1.3.4)\n", - "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n", "Requirement already satisfied: cloudpickle in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.0.0)\n", - "Collecting gymnasium<0.30,>=0.28.1\n", - " Using cached gymnasium-0.29.1-py3-none-any.whl (953 kB)\n", - "Collecting farama-notifications>=0.0.1\n", - " Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n", - "Requirement already satisfied: importlib-metadata>=4.8.0 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n", + "Requirement already satisfied: torch>=1.13 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from stable-baselines3) (2.1.1)\n", "Requirement already satisfied: typing-extensions>=4.3.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.4.0)\n", + "Requirement already satisfied: importlib-metadata>=4.8.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.8.1)\n", + "Requirement already satisfied: farama-notifications>=0.0.1 in c:\\users\\coren\\appdata\\roaming\\python\\python39\\site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)\n", "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.8.0->gymnasium<0.30,>=0.28.1->stable-baselines3) (3.6.0)\n", - "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n", + "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n", "Requirement already satisfied: sympy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (1.9)\n", - "Requirement already satisfied: fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n", "Requirement already satisfied: networkx in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.6.3)\n", - "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (3.3.1)\n", + "Requirement already satisfied: jinja2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2.11.3)\n", + "Requirement already satisfied: 
fsspec in c:\\users\\coren\\anaconda3\\lib\\site-packages (from torch>=1.13->stable-baselines3) (2021.10.1)\n", "Requirement already satisfied: MarkupSafe>=0.23 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from jinja2->torch>=1.13->stable-baselines3) (1.1.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n", "Requirement already satisfied: pillow>=6.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (8.4.0)\n", - "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (1.3.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (2.8.2)\n", + "Requirement already satisfied: pyparsing>=2.2.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (3.0.4)\n", "Requirement already satisfied: cycler>=0.10 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from matplotlib->stable-baselines3) (0.10.0)\n", "Requirement already satisfied: six in c:\\users\\coren\\anaconda3\\lib\\site-packages (from cycler>=0.10->matplotlib->stable-baselines3) (1.16.0)\n", "Requirement already satisfied: pytz>=2017.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from pandas->stable-baselines3) (2021.3)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from sympy->torch>=1.13->stable-baselines3) (1.2.1)\n", - "Installing collected packages: farama-notifications, gymnasium, stable-baselines3\n", - "Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1 stable-baselines3-2.2.1\n", "Requirement already satisfied: moviepy in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (1.0.3)\n", - "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n", + "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n", "Requirement already satisfied: requests<3.0,>=2.8.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.26.0)\n", - "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n", - "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.26.4)\n", "Requirement already satisfied: proglog<=1.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.1.10)\n", - "Requirement already satisfied: imageio<3.0,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (2.9.0)\n", "Requirement already satisfied: tqdm<5.0,>=4.11.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.62.3)\n", + "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (0.4.9)\n", + "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (1.25.0)\n", + "Requirement already satisfied: decorator<5.0,>=4.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from moviepy) (4.4.2)\n", "Requirement already satisfied: pillow in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio<3.0,>=2.5->moviepy) (8.4.0)\n", "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from imageio-ffmpeg>=0.2.0->moviepy) (58.0.4)\n", - "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n", "Requirement already satisfied: idna<4,>=2.5 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (3.2)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2021.10.8)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (2.0.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3.0,>=2.8.1->moviepy) (1.26.7)\n", "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tqdm<5.0,>=4.11.2->moviepy) (0.4.4)\n" ] } @@ -316,30 +405,212 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 15, "id": "14c1bc63", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ac2e8543ce234f088dce3fec9f63d59a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/10000 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAO7UlEQVR4nO3cYazddX3H8fdnvVQE51rXO1Pb6i1J42yME3bDihpDxG2FGUmMD2jCcETTmIFDt8SAPjB7posxSmZgjVbHdKBD3BrChgtqyB4I3ApioXRWUHulrteYgdEsUP3uwflXT6733nN6e8qh5/d+JTc9/9/vf879/dry7jn/cy6pKiRJk+23xr0ASdLpZ+wlqQHGXpIaYOwlqQHGXpIaMDXuBSxlw4YNNTMzM+5lSNIZY//+/T+uqunl5p+XsZ+ZmWFubm7cy5CkM0aS768072UcSWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWrAwNgn2ZvkWJIDy8wnyY1JDid5OMkFi+bXJHkwyZ2jWrQk6eQM88z+s8DOFeYvBbZ1X7uBmxbNXwccXM3iJEmjMTD2VXUv8JMVTrkcuKV6vgGsS7IRIMlm4M+AT41isZKk1RnFNftNwJG+4/luDODjwPuBXw56kCS7k8wlmVtYWBjBsiRJJ4wi9llirJK8BThWVfuHeZCq2lNVs1U1Oz09PYJlSZJOGEXs54EtfcebgSeB1wNvTfI94DbgTUk+N4LvJ0k6SaOI/T7gqu5TOTuAp6rqaFXdUFWbq2oGuAL4alVdOYLvJ0k6SVODTkhyK3AxsCHJPPAh4CyAqroZuAu4DDgM/By4+nQtVpK0OgNjX1W7BswXcM2Ac74OfP1kFiZJGh1/glaSGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBxl6SGmDsJakBA2OfZG+SY0kOLDOfJDcmOZzk4SQXdONbknwtycEkjyS5btSLlyQNZ5hn9p8Fdq4wfymwrfvaDdzUjR8H/qaqXgXsAK5Jsn31S5UkrdbA2FfVvcBPVjjlcuCW6vkGsC7Jxqo6WlXf7B7jp8BBYNMoFi1JOjmjuGa/CTjSdzzPoqgnmQHOB+4bwfeTJJ2kUcQ+S4zVryaTFwFfAt5bVU8v+yDJ7iRzSeYWFhZGsCxJ0gmjiP08sKXveDPwJECSs+iF/vNVdcdKD1JVe6pqtqpmp6enR7AsSdIJo4j9PuCq7lM5O4CnqupokgCfBg5W1cdG8H0kSas0NeiEJLcCFwMbkswDHwLOAqiqm4G7gMuAw8DPgau7u74e+HPg20ke6sY+UFV3jXD9kqQhDIx9Ve0aMF/ANUuM/xdLX8+XJD3H/AlaSWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWqAsZekBhh7SWrAwNgn2ZvkWJIDy8wnyY1JDid5OMkFfXM7kxzq5q4f5cIlScMb5pn9Z4GdK8xfCmzrvnYDNwEkWQN8spvfDuxKsv1UFitJWp2pQSdU1b1JZlY45XLglqoq4BtJ1iXZCMwAh6vqcYAkt3XnPnrKq17Gdbc9yDPHf3m6Hl6STqsXn30WH3n7a07LYw+M/RA2AUf6jue7saXG/2i5B0mym94rA17+8pe
vaiFP/Phn/N+zv1jVfSVp3Nads/a0PfYoYp8lxmqF8SVV1R5gD8Ds7Oyy561k37VvWM3dJGnijSL288CWvuPNwJPA2mXGJUnPsVF89HIfcFX3qZwdwFNVdRR4ANiWZGuStcAV3bmSpOfYwGf2SW4FLgY2JJkHPgScBVBVNwN3AZcBh4GfA1d3c8eTXAvcDawB9lbVI6dhD5KkAYb5NM6uAfMFXLPM3F30/jGQJI2RP0ErSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUAGMvSQ0w9pLUgKFin2RnkkNJDie5fon59Um+nOThJPcneXXf3PuSPJLkQJJbk5w9yg1IkgYbGPska4BPApcC24FdSbYvOu0DwENV9RrgKuAT3X03AX8FzFbVq4E1wBWjW74kaRjDPLO/EDhcVY9X1TPAbcDli87ZDtwDUFWPATNJXtrNTQEvTDIFnAM8OZKVS5KGNkzsNwFH+o7nu7F+3wLeBpDkQuAVwOaq+iHwUeAHwFHgqar6yqkuWpJ0coaJfZYYq0XHHwbWJ3kIeA/wIHA8yXp6rwK2Ai8Dzk1y5ZLfJNmdZC7J3MLCwrDrlyQNYZjYzwNb+o43s+hSTFU9XVVXV9Vr6V2znwaeAN4MPFFVC1X1LHAH8LqlvklV7amq2aqanZ6ePvmdSJKWNUzsHwC2JdmaZC29N1j39Z+QZF03B/Au4N6qepre5ZsdSc5JEuAS4ODoli9JGsbUoBOq6niSa4G76X2aZm9VPZLk3d38zcCrgFuS/AJ4FHhnN3dfktuBbwLH6V3e2XNadiJJWlaqFl9+H7/Z2dmam5sb9zIk6YyRZH9VzS4370/QSlIDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDjL0kNcDYS1IDhop9kp1JDiU5nOT6JebXJ/lykoeT3J/k1X1z65LcnuSxJAeTXDTKDUiSBhsY+yRrgE8ClwLbgV1Jti867QPAQ1X1GuAq4BN9c58A/qOqfh/4A+DgKBYuSRreMM/sLwQOV9XjVfUMcBtw+aJztgP3AFTVY8BMkpcmeTHwRuDT3dwzVfW/o1q8JGk4w8R+E3Ck73i+G+v3LeBtAEkuBF4BbAbOAxaAzyR5MMmnkpy71DdJsjvJXJK5hYWFk9yGJGklw8Q+S4zVouMPA+uTPAS8B3gQOA5MARcAN1XV+cDPgN+45g9QVXuqaraqZqenp4dcviRpGFNDnDMPbOk73gw82X9CVT0NXA2QJMAT3dc5wHxV3dedejvLxF6SdPoM88z+AWBbkq1J1gJXAPv6T+g+cbO2O3wXcG9VPV1VPwKOJHllN3cJ8OiI1i5JGtLAZ/ZVdTzJtcDdwBpgb1U9kuTd3fzNwKuAW5L8gl7M39n3EO8BPt/9Y/A43SsASdJzJ1WLL7+P3+zsbM3NzY17GZJ0xkiyv6pml5v3J2glqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIaYOwlqQHGXpIakKoa9xp+Q5IF4PurvPsG4McjXM6ZwD1Pvtb2C+75ZL2iqqaXm3xexv5UJJmrqtlxr+O55J4nX2v7Bfc8al7GkaQGGHtJasAkxn7PuBcwBu558rW2X3DPIzVx1+wlSb9pEp/ZS5IWMfaS1ICJiX2SnUkOJTm
c5Ppxr+dUJNmS5GtJDiZ5JMl13fhLkvxnku90v67vu88N3d4PJfnTvvE/TPLtbu7GJBnHnoaRZE2SB5Pc2R1P+n7XJbk9yWPdn/VFDez5fd3f6QNJbk1y9qTtOcneJMeSHOgbG9kek7wgyRe68fuSzAy1sKo647+ANcB3gfOAtcC3gO3jXtcp7GcjcEF3+7eB/wa2A38HXN+NXw98pLu9vdvzC4Ct3e/Fmm7ufuAiIMC/A5eOe38r7PuvgX8G7uyOJ32//wi8q7u9Flg3yXsGNgFPAC/sjr8I/MWk7Rl4I3ABcKBvbGR7BP4SuLm7fQXwhaHWNe7fmBH95l4E3N13fANww7jXNcL9/Rvwx8AhYGM3thE4tNR+gbu735ONwGN947uAfxj3fpbZ42bgHuBN/Dr2k7zfF3fhy6LxSd7zJuAI8BJgCrgT+JNJ3DMwsyj2I9vjiXO621P0fuI2g9Y0KZdxTvwlOmG+GzvjdS/RzgfuA15aVUcBul9/rzttuf1v6m4vHn8++jjwfuCXfWOTvN/zgAXgM92lq08lOZcJ3nNV/RD4KPAD4CjwVFV9hQnec59R7vFX96mq48BTwO8OWsCkxH6p63Vn/GdKk7wI+BLw3qp6eqVTlxirFcafV5K8BThWVfuHvcsSY2fMfjtT9F7q31RV5wM/o/fyfjln/J6769SX07tc8TLg3CRXrnSXJcbOqD0PYTV7XNX+JyX288CWvuPNwJNjWstIJDmLXug/X1V3dMP/k2RjN78RONaNL7f/+e724vHnm9cDb03yPeA24E1JPsfk7hd6a52vqvu649vpxX+S9/xm4ImqWqiqZ4E7gNcx2Xs+YZR7/NV9kkwBvwP8ZNACJiX2DwDbkmxNspbemxb7xrymVevedf80cLCqPtY3tQ94R3f7HfSu5Z8Yv6J7l34rsA24v3u5+NMkO7rHvKrvPs8bVXVDVW2uqhl6f3ZfraormdD9AlTVj4AjSV7ZDV0CPMoE75ne5ZsdSc7p1noJcJDJ3vMJo9xj/2O9nd5/L4Nf2Yz7jYwRviFyGb1PrXwX+OC413OKe3kDvZdlDwMPdV+X0bsudw/wne7Xl/Td54Pd3g/R98kEYBY40M39PUO8kTPmvV/Mr9+gnej9Aq8F5ro/538F1jew578FHuvW+0/0PoUyUXsGbqX3nsSz9J6Fv3OUewTOBv4FOEzvEzvnDbMu/3cJktSASbmMI0lagbGXpAYYe0lqgLGXpAYYe0lqgLGXpAYYe0lqwP8DUCQh/QV429IAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "import gymnasium as gym\n", "from stable_baselines3 import A2C\n", + "import matplotlib.pyplot as plt\n", "\n", - "\n", - "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")#, render_mode=\"human\")\n", "\n", "model = A2C(\"MlpPolicy\", env, verbose=0)\n", "model.learn(total_timesteps=10_000)\n", "\n", "vec_env = model.get_env()\n", "obs = vec_env.reset()\n", - "for i in range(1000):\n", + "rewards = []\n", + "for i in tqdm(range(10000)):\n", " action, 
_state = model.predict(obs, deterministic=True)\n", " obs, reward, done, info = vec_env.step(action)\n", + " rewards.append(reward)\n", " #vec_env.render(\"human\")\n", " # VecEnv resets automatically\n", " # if done:\n", " # obs = vec_env.reset()\n", - "env.close()" + "env.close()\n", + "\n", + "plt.plot(rewards)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "42f0141f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cpu device\n" + ] + }, + { + "ename": "ImportError", + "evalue": "Missing shimmy installation. You provided an OpenAI Gym environment. Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. In order to use OpenAI Gym environments with SB3, you need to install shimmy (`pip install 'shimmy>=0.2.1'`).", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\vec_env\\patch_gym.py\u001b[0m in \u001b[0;36m_patch_env\u001b[1;34m(env)\u001b[0m\n\u001b[0;32m 39\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 40\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mshimmy\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 41\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'shimmy'", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[1;31mImportError\u001b[0m Traceback (most recent call last)", + 
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/716802256.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;31m# Lists to store rewards during training\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\a2c\\a2c.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, rms_prop_eps, use_rms_prop, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, normalize_advantage, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)\u001b[0m\n\u001b[0;32m 90\u001b[0m \u001b[0m_init_setup_model\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m ):\n\u001b[1;32m---> 92\u001b[1;33m super().__init__(\n\u001b[0m\u001b[0;32m 93\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 94\u001b[0m 
\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\on_policy_algorithm.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, stats_window_size, tensorboard_log, monitor_wrapper, policy_kwargs, verbose, seed, device, _init_setup_model, supported_action_spaces)\u001b[0m\n\u001b[0;32m 83\u001b[0m \u001b[0msupported_action_spaces\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTuple\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mType\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mspaces\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSpace\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m...\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 84\u001b[0m ):\n\u001b[1;32m---> 85\u001b[1;33m super().__init__(\n\u001b[0m\u001b[0;32m 86\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 87\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, policy, env, learning_rate, policy_kwargs, stats_window_size, tensorboard_log, verbose, device, support_multi_env, monitor_wrapper, seed, use_sde, sde_sample_freq, supported_action_spaces)\u001b[0m\n\u001b[0;32m 167\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0menv\u001b[0m \u001b[1;32mis\u001b[0m 
\u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 168\u001b[0m \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmaybe_make_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 169\u001b[1;33m \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_wrap_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmonitor_wrapper\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 170\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m_wrap_env\u001b[1;34m(env, verbose, monitor_wrapper)\u001b[0m\n\u001b[0;32m 214\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mVecEnv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[1;31m# Patch to support gym 0.21/0.26 and gymnasium\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 216\u001b[1;33m \u001b[0menv\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0m_patch_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 217\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_wrapped\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mMonitor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mmonitor_wrapper\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 218\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mverbose\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\vec_env\\patch_gym.py\u001b[0m in \u001b[0;36m_patch_env\u001b[1;34m(env)\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mshimmy\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 42\u001b[1;33m raise ImportError(\n\u001b[0m\u001b[0;32m 43\u001b[0m \u001b[1;34m\"Missing shimmy installation. You provided an OpenAI Gym environment. \"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[1;34m\"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. \"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mImportError\u001b[0m: Missing shimmy installation. You provided an OpenAI Gym environment. Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. In order to use OpenAI Gym environments with SB3, you need to install shimmy (`pip install 'shimmy>=0.2.1'`)." 
+ ] + } + ], + "source": [ + "import gym\n", + "from stable_baselines3 import A2C\n", + "import matplotlib.pyplot as plt\n", + "\n", + "env = gym.make(\"CartPole-v1\")\n", + "\n", + "model = A2C(\"MlpPolicy\", env, verbose=1)\n", + "\n", + "# Lists to store rewards during training\n", + "episode_rewards = []\n", + "\n", + "total_timesteps = 10000\n", + "for timestep in range(total_timesteps):\n", + " # Train for one step\n", + " obs = env.reset()\n", + " episode_reward = 0\n", + " done = False\n", + " while not done:\n", + " action, _ = model.predict(obs, deterministic=True)\n", + " obs, reward, done, _ = env.step(action)\n", + " episode_reward += reward\n", + " \n", + " # Log episode reward\n", + " episode_rewards.append(episode_reward)\n", + "\n", + " # Display training progress\n", + " if timestep % 1000 == 0:\n", + " print(f\"Timestep: {timestep}/{total_timesteps}\")\n", + "\n", + "# Plot rewards over training\n", + "plt.plot(episode_rewards)\n", + "plt.xlabel('Episode')\n", + "plt.ylabel('Episode Reward')\n", + "plt.title('Evolution of Episode Reward during Training')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6d4703d4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cpu device\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'A2C' object has no attribute '_logger'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/1163910362.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mtimestep\u001b[0m \u001b[1;32min\u001b[0m 
\u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;31m# Train for one step\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 22\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 23\u001b[0m \u001b[1;31m# Log mean episode reward every 1000 timesteps\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\a2c\\a2c.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 140\u001b[0m \u001b[1;31m# Update optimizer learning rate\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_update_learning_rate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 142\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 143\u001b[0m \u001b[1;31m# This will only loop once (get all data in one go)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36m_update_learning_rate\u001b[1;34m(self, optimizers)\u001b[0m\n\u001b[0;32m 293\u001b[0m \"\"\"\n\u001b[0;32m 294\u001b[0m \u001b[1;31m# Log the current learning rate\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 295\u001b[1;33m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"train/learning_rate\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlr_schedule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_current_progress_remaining\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 296\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 297\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python39\\site-packages\\stable_baselines3\\common\\base_class.py\u001b[0m in \u001b[0;36mlogger\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 269\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mLogger\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 270\u001b[0m \u001b[1;34m\"\"\"Getter for the logger object.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_logger\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 272\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 273\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_setup_lr_schedule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mAttributeError\u001b[0m: 'A2C' 
object has no attribute '_logger'" + ] + } + ], + "source": [ + "import gym\n", + "from stable_baselines3 import A2C\n", + "from stable_baselines3.common.env_util import make_vec_env\n", + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Create and wrap the CartPole environment\n", + "env = make_vec_env('CartPole-v1', n_envs=4)\n", + "\n", + "# Define the A2C model\n", + "model = A2C('MlpPolicy', env, verbose=1)\n", + "\n", + "# Lists to store metrics\n", + "mean_rewards = []\n", + "entropies = []\n", + "\n", + "# Train the model\n", + "total_timesteps = 10000\n", + "for timestep in range(total_timesteps):\n", + " # Train for one step\n", + " model.train()\n", + " \n", + " # Log mean episode reward every 1000 timesteps\n", + " if timestep % 1000 == 0:\n", + " mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)\n", + " mean_rewards.append(mean_reward)\n", + " \n", + " # Log entropy every 100 timesteps\n", + " if timestep % 100 == 0:\n", + " action_probs = model.policy.action_net(torch.FloatTensor(env.reset()))\n", + " entropy = -(action_probs * action_probs.log()).sum(dim=-1).mean().item()\n", + " entropies.append(entropy)\n", + "\n", + "# Evaluate the model\n", + "mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)\n", + "print(f\"Mean reward: {mean_reward}\")\n", + "\n", + "# Plotting\n", + "def plot_metrics(metrics, ylabel, title):\n", + " plt.plot(metrics)\n", + " plt.xlabel('Timestep')\n", + " plt.ylabel(ylabel)\n", + " plt.title(title)\n", + " plt.show()\n", + "\n", + "plot_metrics(mean_rewards, 'Mean Reward', 'Mean Reward Over Time')\n", + "plot_metrics(entropies, 'Entropy', 'Entropy of Action Distribution Over Time')" ] }, { @@ -364,10 +635,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "cd890835", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Collecting huggingface-sb3==2.3.1\n", + " Downloading huggingface_sb3-2.3.1-py3-none-any.whl (9.5 kB)\n", + "Collecting huggingface-hub~=0.8\n", + " Downloading huggingface_hub-0.20.3-py3-none-any.whl (330 kB)\n", + "Requirement already satisfied: pyyaml~=6.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (6.0)\n", + "Collecting wasabi\n", + " Downloading wasabi-1.1.2-py3-none-any.whl (27 kB)\n", + "Requirement already satisfied: cloudpickle>=1.6 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (2.0.0)\n", + "Requirement already satisfied: numpy in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-sb3==2.3.1) (1.25.0)\n", + "Requirement already satisfied: packaging>=20.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (21.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.4.0)\n", + "Requirement already satisfied: requests in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.26.0)\n", + "Requirement already satisfied: filelock in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.3.1)\n", + "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from huggingface-hub~=0.8->huggingface-sb3==2.3.1) (4.62.3)\n", + "Collecting fsspec>=2023.5.0\n", + " Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from packaging>=20.9->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.0.4)\n", + "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from 
tqdm>=4.42.1->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (0.4.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2021.10.8)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (3.2)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests->huggingface-hub~=0.8->huggingface-sb3==2.3.1) (1.26.7)\n", + "Collecting colorama\n", + " Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n", + "Installing collected packages: colorama, fsspec, wasabi, huggingface-hub, huggingface-sb3\n", + " Attempting uninstall: colorama\n", + " Found existing installation: colorama 0.4.4\n", + " Uninstalling colorama-0.4.4:\n", + " Successfully uninstalled colorama-0.4.4\n", + " Attempting uninstall: fsspec\n", + " Found existing installation: fsspec 2021.10.1\n", + " Uninstalling fsspec-2021.10.1:\n", + " Successfully uninstalled fsspec-2021.10.1\n", + "Successfully installed colorama-0.4.6 fsspec-2024.2.0 huggingface-hub-0.20.3 huggingface-sb3-2.3.1 wasabi-1.1.2\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "pip install huggingface-sb3==2.3.1" ] @@ -398,14 +713,238 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "6645c23a", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting wandb\n", + " Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)\n", + "Collecting tensorboard\n", + " Downloading 
tensorboard-2.16.2-py3-none-any.whl (5.5 MB)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "conda-repo-cli 1.0.4 requires pathlib, which is not installed.\n", + "anaconda-project 0.10.1 requires ruamel-yaml, which is not installed.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: requests<3,>=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n", + "Requirement already satisfied: setuptools in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (58.0.4)\n", + "Requirement already satisfied: appdirs>=1.4.3 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: PyYAML in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n", + "Requirement already satisfied: typing-extensions in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (4.4.0)\n", + "Collecting GitPython!=3.1.29,>=1.0.0\n", + " Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.40.5-py2.py3-none-any.whl (258 kB)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n", + "Collecting setproctitle\n", + " Downloading setproctitle-1.3.3-cp39-cp39-win_amd64.whl (11 kB)\n", + "Collecting protobuf!=4.21.0,<5,>=3.19.0\n", + " Downloading protobuf-4.25.3-cp39-cp39-win_amd64.whl (413 kB)\n", + "Requirement already satisfied: psutil>=5.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n", + "Collecting docker-pycreds>=0.4.0\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 
kB)\n", + "Requirement already satisfied: six>1.9 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.16.0)\n", + "Collecting tensorboard-data-server<0.8.0,>=0.7.0\n", + " Downloading tensorboard_data_server-0.7.2-py3-none-any.whl (2.4 kB)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (2.0.2)\n", + "Collecting markdown>=2.6.8\n", + " Downloading Markdown-3.5.2-py3-none-any.whl (103 kB)\n", + "Collecting absl-py>=0.4\n", + " Downloading absl_py-2.1.0-py3-none-any.whl (133 kB)\n", + "Requirement already satisfied: numpy>=1.12.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from tensorboard) (1.25.0)\n", + "Collecting grpcio>=1.48.2\n", + " Downloading grpcio-1.62.0-cp39-cp39-win_amd64.whl (3.8 MB)\n", + "Requirement already satisfied: colorama in c:\\users\\coren\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n", + "Collecting gitdb<5,>=4.0.1\n", + " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", + "Collecting smmap<6,>=3.0.1\n", + " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", + "Requirement already satisfied: importlib-metadata>=4.4 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from markdown>=2.6.8->tensorboard) (4.8.1)\n", + "Requirement already satisfied: zipp>=0.5 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (3.6.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in 
c:\\users\\coren\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n", + "Collecting urllib3<1.27,>=1.21.1\n", + " Downloading urllib3-1.26.18-py2.py3-none-any.whl (143 kB)\n", + "Installing collected packages: smmap, urllib3, gitdb, tensorboard-data-server, setproctitle, sentry-sdk, protobuf, markdown, grpcio, GitPython, docker-pycreds, absl-py, wandb, tensorboard\n", + " Attempting uninstall: urllib3\n", + " Found existing installation: urllib3 1.26.7\n", + " Uninstalling urllib3-1.26.7:\n", + " Successfully uninstalled urllib3-1.26.7\n", + "Successfully installed GitPython-3.1.42 absl-py-2.1.0 docker-pycreds-0.4.0 gitdb-4.0.11 grpcio-1.62.0 markdown-3.5.2 protobuf-4.25.3 sentry-sdk-1.40.5 setproctitle-1.3.3 smmap-5.0.1 tensorboard-2.16.2 tensorboard-data-server-0.7.2 urllib3-1.26.18 wandb-0.16.3\n" + ] + } + ], "source": [ "pip install wandb tensorboard" ] }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f94466df", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. 
(Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:\n", + "Traceback (most recent call last):\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 1172, in init\n", + " wi.setup(kwargs)\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\", line 306, in setup\n", + " wandb_login._login(\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 317, in _login\n", + " wlogin.prompt_api_key()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 240, in prompt_api_key\n", + " key, status = self._prompt_api_key()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\", line 220, in _prompt_api_key\n", + " key = apikey.prompt_api_key(\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\", line 151, in prompt_api_key\n", + " key = input_callback(api_ask).strip()\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 168, in prompt\n", + " value = prompt_func(prompt)\n", + " File \"C:\\Users\\coren\\anaconda3\\lib\\site-packages\\click\\termui.py\", line 150, in prompt_func\n", + " raise Abort() from None\n", + "click.exceptions.Abort\n" + ] + }, + { + "ename": "Error", + "evalue": "An unexpected error occurred", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAbort\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, 
dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m 1171\u001b[0m \u001b[0mwi\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_WandbInit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1172\u001b[1;33m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1173\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mwi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36msetup\u001b[1;34m(self, kwargs)\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_offline\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msettings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_noop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m wandb_login._login(\n\u001b[0m\u001b[0;32m 307\u001b[0m \u001b[0manonymous\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"anonymous\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_login\u001b[1;34m(anonymous, key, relogin, host, force, timeout, _backend, _silent, _disable_warning, _entity)\u001b[0m\n\u001b[0;32m 316\u001b[0m 
\u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 317\u001b[1;33m \u001b[0mwlogin\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 318\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 239\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mprompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 240\u001b[1;33m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_prompt_api_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 241\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mstatus\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mApiKeyStatus\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mNOTTY\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_login.py\u001b[0m in \u001b[0;36m_prompt_api_key\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 219\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 220\u001b[1;33m key = apikey.prompt_api_key(\n\u001b[0m\u001b[0;32m 221\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\lib\\apikey.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[1;34m(settings, 
api, input_callback, browser_callback, no_offline, no_create, local)\u001b[0m\n\u001b[0;32m 150\u001b[0m )\n\u001b[1;32m--> 151\u001b[1;33m \u001b[0mkey\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_callback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mapi_ask\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 152\u001b[0m \u001b[0mwrite_key\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapi\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt\u001b[1;34m(text, default, hide_input, confirmation_prompt, type, value_proc, prompt_suffix, show_default, err, show_choices)\u001b[0m\n\u001b[0;32m 167\u001b[0m \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 168\u001b[1;33m \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mprompt_func\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 169\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\click\\termui.py\u001b[0m in \u001b[0;36mprompt_func\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m 149\u001b[0m \u001b[0mecho\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0merr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 150\u001b[1;33m \u001b[1;32mraise\u001b[0m 
\u001b[0mAbort\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 151\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mAbort\u001b[0m: ", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[1;31mError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2820517193.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;34m\"env_name\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m\"CartPole-v1\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m }\n\u001b[1;32m---> 14\u001b[1;33m run = wandb.init(\n\u001b[0m\u001b[0;32m 15\u001b[0m \u001b[0mproject\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"sb3\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\sdk\\wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[1;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[0;32m 1212\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtermerror\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Abnormal program exit\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1213\u001b[0m 
\u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_exit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1214\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"An unexpected error occurred\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0merror_seen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1215\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mError\u001b[0m: An unexpected error occurred" + ] + } + ], + "source": [ + "import gym\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.common.monitor import Monitor\n", + "from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder\n", + "import wandb\n", + "from wandb.integration.sb3 import WandbCallback\n", + "\n", + "\n", + "config = {\n", + " \"policy_type\": \"MlpPolicy\",\n", + " \"total_timesteps\": 25000,\n", + " \"env_name\": \"CartPole-v1\",\n", + "}\n", + "run = wandb.init(\n", + " project=\"sb3\",\n", + " config=config,\n", + " sync_tensorboard=True, # auto-upload sb3's tensorboard metrics\n", + " monitor_gym=True, # auto-upload the videos of agents playing the game\n", + " save_code=True, # optional\n", + ")\n", + "\n", + "\n", + "def make_env():\n", + " env = gym.make(config[\"env_name\"])\n", + " env = Monitor(env) # record stats such as returns\n", + " return env\n", + "\n", + "\n", + "env = DummyVecEnv([make_env])\n", + "env = VecVideoRecorder(\n", + " env,\n", + " f\"videos/{run.id}\",\n", + " record_video_trigger=lambda x: x % 2000 == 0,\n", + " video_length=200,\n", + ")\n", + "model = PPO(config[\"policy_type\"], env, verbose=1, tensorboard_log=f\"runs/{run.id}\")\n", + "model.learn(\n", + " total_timesteps=config[\"total_timesteps\"],\n", + " callback=WandbCallback(\n", + " 
gradient_save_freq=100,\n", + " model_save_path=f\"models/{run.id}\",\n", + " verbose=2,\n", + " ),\n", + ")\n", + "run.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4dda548e", + "metadata": {}, + "outputs": [ + { + "ename": "Error", + "evalue": "You must call wandb.init() before WandbCallback()", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_22904/2131705787.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mA2C\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"MlpPolicy\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtotal_timesteps\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10_000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mWandbCallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mvec_env\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_env\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\wandb\\integration\\sb3\\sb3.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, verbose, model_save_path, model_save_freq, 
gradient_save_freq, log)\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mwandb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"You must call wandb.init() before WandbCallback()\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mwb_telemetry\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mtel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msb3\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mError\u001b[0m: You must call wandb.init() before WandbCallback()" + ] + } + ], + "source": [ + "from wandb.integration.sb3 import WandbCallback\n", + "\n", + "import gymnasium as gym\n", + "from stable_baselines3 import A2C\n", + "\n", + "\n", + "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", + "\n", + "model = A2C(\"MlpPolicy\", env, verbose=0)\n", + "model.learn(total_timesteps=10_000, callback=WandbCallback())\n", + "\n", + "vec_env = model.get_env()\n", + "obs = vec_env.reset()\n", + "for i in tqdm(range(1000)):\n", + " action, _state = model.predict(obs, deterministic=True)\n", + " 
obs, reward, done, info = vec_env.step(action)\n", + " #vec_env.render(\"human\")\n", + " # VecEnv resets automatically\n", + " # if done:\n", + " # obs = vec_env.reset()\n", + "env.close()" + ] + }, { "cell_type": "markdown", "id": "9af0d167",