From fa4e312ef63f69815684d7fd4ecef58af3697c4c Mon Sep 17 00:00:00 2001
From: number_cruncher <lennart.oestreich@stud.tu-darmstadt.de>
Date: Sun, 16 Mar 2025 17:25:46 +0100
Subject: [PATCH] readme and model rename

---
 README.md                         | 168 +-----------------------------
 a2c_sb3_cartpole.py               |   3 -
 a2c_sb3_panda_reach.py            |  48 +++++++++
 evaluate_reinforce_cartpole.py    |  41 ++++++++
 evaluate_reinforce_cartpole.py.py |  32 ------
 policy.pth                        | Bin 5512 -> 0 bytes
 reinforce_cartpole.pth            | Bin 0 -> 5800 bytes
 reinforce_cartpole.py             |  79 ++++++--------
 8 files changed, 121 insertions(+), 250 deletions(-)
 create mode 100644 a2c_sb3_panda_reach.py
 create mode 100644 evaluate_reinforce_cartpole.py
 delete mode 100644 evaluate_reinforce_cartpole.py.py
 delete mode 100644 policy.pth
 create mode 100644 reinforce_cartpole.pth

diff --git a/README.md b/README.md
index d1802c0..329e861 100644
--- a/README.md
+++ b/README.md
@@ -4,103 +4,11 @@ MSO 3.4 Apprentissage Automatique
 
 # 
 
-In this hands-on project, we will first implement a simple RL algorithm and apply it to solve the CartPole-v1 environment. Once we become familiar with the basic workflow, we will learn to use various tools for machine learning model training, monitoring, and sharing, by applying these tools to train a robotic arm.
-
-## To be handed in
-
-This work must be done individually. The expected output is a repository named `hands-on-rl` on https://gitlab.ec-lyon.fr. 
-
-We assume that `git` is installed, and that you are familiar with the basic `git` commands. (Optionnaly, you can use GitHub Desktop.)
-We also assume that you have access to the [ECL GitLab](https://gitlab.ec-lyon.fr/). If necessary, please consult [this tutorial](https://gitlab.ec-lyon.fr/edelland/inf_tc2/-/blob/main/Tutoriel_gitlab/tutoriel_gitlab.md).
-
-Your repository must contain a `README.md` file that explains **briefly** the successive steps of the project. It must be private, so you need to add your teacher as "developer" member.
-
-Throughout the subject, you will find a 🛠 symbol indicating that a specific production is expected.
-
-The last commit is due before 11:59 pm on March 17, 2025. Subsequent commits will not be considered.
-
-> ⚠️ **Warning**
-> Ensure that you only commit the files that are requested. For example, your directory should not contain the generated `.zip` files, nor the `runs` folder... At the end, your repository must contain one `README.md`, three python scripts, and optionally image files for the plots.
-
-## Before you start
-
-Make sure you know the basics of Reinforcement Learning. In case of need, you can refer to the [introduction of the Hugging Face RL course](https://huggingface.co/blog/deep-rl-intro).
-
-## Introduction to Gym
-
-[Gym](https://gymnasium.farama.org/) is a framework for developing and evaluating reinforcement learning environments. It offers various environments, including classic control and toy text scenarios, to test RL algorithms.
-
-### Installation
-
-We recommend to use Python virtual environnements to install the required modules : https://docs.python.org/3/library/venv.html or https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html if you are using conda.
-
-First, install Pytorch : https://pytorch.org/get-started/locally.
-
-Then install the following modules :
-
-
-```sh
-pip install gymnasium
-```
-
-```sh
-pip install "gymnasium[classic-control]"
-```
-
-
-### Usage
-
-Here is an example of how to use Gym to solve the `CartPole-v1` environment [Documentation](https://gymnasium.farama.org/environments/classic_control/cart_pole/):
-
-```python
-import gymnasium as gym
-
-# Create the environment
-env = gym.make("CartPole-v1", render_mode="human")
-
-# Reset the environment and get the initial observation
-observation = env.reset()
-
-for _ in range(100):
-    # Select a random action from the action space
-    action = env.action_space.sample()
-    # Apply the action to the environment
-    # Returns next observation, reward, done signal (indicating
-    # if the episode has ended), and an additional info dictionary
-    observation, reward, terminated, truncated, info = env.step(action)
-    # Render the environment to visualize the agent's behavior
-    env.render()
-    if terminated: 
-        # Terminated before max step
-        break
-
-env.close()
-```
 
 ## REINFORCE
 
-The REINFORCE algorithm (also known as Vanilla Policy Gradient) is a policy gradient method that optimizes the policy directly using gradient descent. The following is the pseudocode of the REINFORCE algorithm:
+The REINFORCE algorithm (also known as Vanilla Policy Gradient) is a policy gradient method that optimizes the policy directly using gradient descent.
 
-```txt
-Setup the CartPole environment
-Setup the agent as a simple neural network with:
-    - One fully connected layer with 128 units and ReLU activation followed by a dropout layer
-    - One fully connected layer followed by softmax activation
-Repeat 500 times:
-    Reset the environment
-    Reset the buffer
-    Repeat until the end of the episode:
-        Compute action probabilities 
-        Sample the action based on the probabilities and store its probability in the buffer 
-        Step the environment with the action
-        Compute and store in the buffer the return using gamma=0.99 
-    Normalize the return
-    Compute the policy loss as -sum(log(prob) * return)
-    Update the policy using an Adam optimizer and a learning rate of 5e-3
-Save the model weights
-```
-
-To learn more about REINFORCE, you can refer to [this unit](https://huggingface.co/learn/deep-rl-course/unit4/policy-gradient).
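+The full training loop is in `reinforce_cartpole.py`; the policy update at the end of each episode boils down to the sketch below (`rewards`, `log_probs`, `gamma` and `optimizer` come from the surrounding loop):
+
+```python
+# discounted returns, accumulated backwards over the episode
+returns, G = [], 0.0
+for r in reversed(rewards):
+    G = r + gamma * G
+    returns.insert(0, G)
+returns = torch.tensor(returns)
+returns = (returns - returns.mean()) / (returns.std() + 1e-10)  # normalize to reduce gradient variance
+loss = -(log_probs * returns).sum()  # REINFORCE policy-gradient loss
+optimizer.zero_grad()
+loss.backward()
+optimizer.step()
+```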
 
 > 🛠 **To be handed in**
 > Use PyTorch to implement REINFORCE and solve the CartPole environement. Share the code in `reinforce_cartpole.py`, and share a plot showing the total reward accross episodes in the `README.md`. Also, share a file `reinforce_cartpole.pth` containing the learned weights. For saving and loading PyTorch models, check [this tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference)
@@ -112,70 +20,21 @@ Now that you have trained your model, it is time to evaluate its performance. Ru
 > 🛠 **To be handed in**
 > Implement a script which loads your saved model and use it to solve the cartpole enviroment. Run 100 evaluations and share the final success rate across all evaluations in the `README.md`. Share the code in `evaluate_reinforce_cartpole.py`.
 
+From the OpenAI Gym wiki, the environment counts as solved when the average reward is greater than or equal to 195 over 100 consecutive trials.
+With the evaluation script I used, the success rate is 1.0 when each episode is allowed to run for the maximum number of steps the environment offers.
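+
+A condensed view of the check in `evaluate_reinforce_cartpole.py`, where `eval_policy` runs one episode and returns its total reward:
+
+```python
+number_of_solves = sum(
+    eval_policy(eval_length, policy, env) >= 195  # one full episode per evaluation
+    for _ in range(num_evals)
+)
+success_rate = number_of_solves / num_evals
+```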
 
 ## Familiarization with a complete RL pipeline: Application to training a robotic arm
 
-In this section, you will use the Stable-Baselines3 package to train a robotic arm using RL. You'll get familiar with several widely-used tools for training, monitoring and sharing machine learning models.
-
-### Get familiar with Stable-Baselines3
-
 Stable-Baselines3 (SB3) is a high-level RL library that provides various algorithms and integrated tools to easily train and test reinforcement learning models.
 
-#### Installation
-
-```sh
-pip install stable-baselines3
-pip install "stable-baselines3[extra]"
-pip install moviepy
-```
-
-#### Usage
-
-Use the [Stable-Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/) to implement the code to solve the CartPole environment with the Advantage Actor-Critic (A2C) algorithm.
-
-
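+A2C on CartPole follows the standard SB3 pattern (a minimal sketch; the full script, including W&B tracking, is in `a2c_sb3_cartpole.py`):
+
+```python
+from stable_baselines3 import A2C
+
+model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
+model.learn(total_timesteps=10_000)  # illustrative budget
+```
+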
 > 🛠 **To be handed in**
 > Store the code in `a2c_sb3_cartpole.py`. Unless otherwise stated, you'll work upon this file for the next sections.
 
-### Get familiar with Hugging Face Hub
-
-Hugging Face Hub is a platform for easy sharing and versioning of trained machine learning models. With Hugging Face Hub, you can quickly and easily share your models with others and make them usable through the API. For example, see the trained A2C agent for CartPole: https://huggingface.co/sb3/a2c-CartPole-v1. Hugging Face Hub provides an API to download and upload SB3 models.
-
-#### Installation of `huggingface_sb3`
-
-```sh
-pip install huggingface-sb3
-```
-
-#### Upload the model on the Hub
-
-Follow the [Hugging Face Hub documentation](https://huggingface.co/docs/hub/stable-baselines3) to upload the previously learned model to the Hub.
-
 > 🛠 **To be handed in**
 > Link the trained model in the `README.md` file.
 
-> 📝 **Note**
->  [RL-Zoo3](https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html) provides more advanced features to save hyperparameters, generate renderings and metrics. Feel free to try them.
-
-### Get familiar with Weights & Biases
-
-Weights & Biases (W&B) is a tool for machine learning experiment management. With W&B, you can track and compare your experiments, visualize your model training and performance.
-
-#### Installation
-
-You'll need to install `wand`.
-
-```shell
-pip install wandb 
-```
-
-Use the documentation of [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) and [Weights & Biases](https://docs.wandb.ai/guides/integrations/stable-baselines-3) to track the CartPole training. Make the run public.
-
 🛠 Share the link of the wandb run in the `README.md` file.
 
-> ⚠️ **Warning**
-> Make sure to make the run public! If it is not possible (due to the restrictions on your account), you can create a WandB [report](https://docs.wandb.ai/guides/reports/create-a-report/), add all relevant graphs and any textual descriptions or explanations you find pertinent, then download a PDF file (landscape format) and upload it along with the code to GitLab. Make sure to arrange the plots in a way that makes them understandable in the PDF (e.g., one graph per row, correct axes, etc.). Specify which report corresponds to which experiment.
-
 wandb: https://wandb.ai/lennartecl-centrale-lyon/sb3?nw=nwuserlennartecl
 hugging: https://huggingface.co/lennartoe/Cartpole-v1/tree/main
 
@@ -183,32 +42,9 @@ hugging: https://huggingface.co/lennartoe/Cartpole-v1/tree/main
 
 [Panda-gym](https://github.com/qgallouedec/panda-gym) is a collection of environments for robotic simulation and control. It provides a range of challenges for training robotic agents in a simulated environment. In this section, you will get familiar with one of the environments provided by panda-gym, the `PandaReachJointsDense-v3`. The objective is to learn how to reach any point in 3D space by directly controlling the robot's articulations.
 
-#### Installation
-
-```shell
-pip install panda-gym==3.0.7
-```
-
-#### Train, track, and share
-
-Use the Stable-Baselines3 package to train A2C model on the `PandaReachJointsDense-v3` environment. 500k timesteps should be enough. Track the environment with Weights & Biases. Once the training is over, upload the trained model on the Hub.
-
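+Importing `panda_gym` registers its environments with Gymnasium, so creating the environment is a one-liner (minimal sketch; the full training script is `a2c_sb3_panda_reach.py`):
+
+```python
+import panda_gym  # registers the Panda environments on import
+import gymnasium as gym
+
+env = gym.make("PandaReachJointsDense-v3")
+```
+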
 > 🛠 **To be handed in**
 > Share all the code in `a2c_sb3_panda_reach.py`. Share the link of the wandb run and the trained model in the `README.md` file.
 
 wandb: https://wandb.ai/lennartecl-centrale-lyon/pandasgym_sb3?nw=nwuserlennartecl
 hugging: https://huggingface.co/lennartoe/PandaReachJointsDense-v3/tree/main
 
-## Contribute
-
-This tutorial may contain errors, inaccuracies, typos or areas for improvement. Feel free to contribute to its improvement by opening an issue.
-
-## Author
-
-Quentin Gallouédec
-
-Updates by Bruno Machado, Léo Schneider, Emmanuel Dellandréa
-
-## License
-
-MIT
diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
index ee1087d..f322220 100644
--- a/a2c_sb3_cartpole.py
+++ b/a2c_sb3_cartpole.py
@@ -32,9 +32,6 @@ for i in range(1000):
     action, _state = model.predict(obs, deterministic=True)
     obs, reward, done, info = vec_env.step(action)
     vec_env.render("human")
-    # VecEnv resets automatically
-    # if done:
-    #   obs = vec_env.reset()
 
 run.finish()
 
diff --git a/a2c_sb3_panda_reach.py b/a2c_sb3_panda_reach.py
new file mode 100644
index 0000000..b33f309
--- /dev/null
+++ b/a2c_sb3_panda_reach.py
@@ -0,0 +1,48 @@
+import panda_gym
+import gymnasium as gym
+
+from stable_baselines3 import A2C
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import package_to_hub
+
+
+# from documentation of wandb
+config = {
+    "policy_type": "MultiInputPolicy",
+    "total_timesteps": 50000,
+    "env_name": "PandaReachJointsDense-v3",
+}
+run = wandb.init(
+    project="pandasgym_sb3",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+
+env = gym.make("PandaReachJointsDense-v3", render_mode="rgb_array")
+
+model = A2C("MultiInputPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
+#model = A2C("MlpPolicy", env, )
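+# WandbCallback periodically saves the model and streams the SB3 tensorboard logs to the W&B run (assumes `wandb login` has been run)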
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=100,
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    ),
+)
+#model.learn(total_timesteps=10_000)
+vec_env = model.get_env()
+obs = vec_env.reset()
+for i in range(1000):
+    action, _state = model.predict(obs, deterministic=True)
+    obs, reward, done, info = vec_env.step(action)
+    vec_env.render("human")
+    # VecEnv resets automatically
+    # if done:
+    #   obs = vec_env.reset()
+
+run.finish()
+
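+# package_to_hub evaluates the model, records a replay video and pushes everything to the Hugging Face Hub;
+# it assumes the machine is authenticated (e.g. via `huggingface-cli login`)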
+package_to_hub(model=model, 
+               model_name="PandaReachJointsDense-v3",
+               model_architecture="A2C",
+               env_id="PandaReachJointsDense-v3",
+               eval_env=env,
+               repo_id="lennartoe/PandaReachJointsDense-v3",
+               commit_message="First commit")
\ No newline at end of file
diff --git a/evaluate_reinforce_cartpole.py b/evaluate_reinforce_cartpole.py
new file mode 100644
index 0000000..1021baf
--- /dev/null
+++ b/evaluate_reinforce_cartpole.py
@@ -0,0 +1,41 @@
+import gymnasium as gym
+import torch
+from reinforce_cartpole import Policy
+
+def eval_policy(eval_length, policy, env):
+    # Reset the environment and get the initial observation
+    observation = env.reset()[0]
+    rewards = []
+
+    for step in range(eval_length):
+        # sample action from policy
+        action_probs = policy(torch.from_numpy(observation).float())
+        action = torch.distributions.Categorical(action_probs).sample()
+        observation, reward, terminated, truncated, info = env.step(action.numpy())
+        rewards.append(reward)
+        # visualize the agent's behavior (requires creating the env with render_mode="human")
+        #env.render()
+        if terminated or truncated: 
+            break
+    return sum(rewards)
+# Create the environment
+env = gym.make("CartPole-v1")
+
+policy = Policy()
+# load learned policy
+policy.load_state_dict(torch.load('reinforce_cartpole.pth', weights_only=True))
+policy.eval()
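+# eval() puts the network in inference mode, which disables the dropout layer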
+
+eval_length = env.spec.max_episode_steps
+num_evals = 100
+number_of_solves = 0
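+# an episode counts as solved when its total reward reaches 195, the CartPole-v1 threshold from the OpenAI Gym wiki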
+for _ in range(num_evals):
+    sum_reward = eval_policy(eval_length, policy, env)
+    print(f"Episode reward: {sum_reward}")
+    if sum_reward >= 195:
+        number_of_solves += 1
+    
+success_rate = number_of_solves / num_evals
+print(f"Success rate: {success_rate}")
+
+env.close()
diff --git a/evaluate_reinforce_cartpole.py.py b/evaluate_reinforce_cartpole.py.py
deleted file mode 100644
index 01479f7..0000000
--- a/evaluate_reinforce_cartpole.py.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import gymnasium as gym
-import torch
-from reinforce_cartpole import Policy
-# Create the environment
-env = gym.make("CartPole-v1", render_mode="human")
-
-# Reset the environment and get the initial observation
-observation = env.reset()[0]
-
-policy = Policy()
-# load learned policy
-policy.load_state_dict(torch.load('policy.pth', weights_only=True))
-policy.eval()
-
-for _ in range(200):
-    # sample action from policy
-    print(observation)
-    print(torch.from_numpy(observation).float())
-    action_probs = policy(torch.from_numpy(observation).float())
-    action = torch.distributions.Categorical(action_probs).sample()
-    # Apply the action to the environment
-    # Returns next observation, reward, done signal (indicating
-    # if the episode has ended), and an additional info dictionary
-    observation, reward, terminated, truncated, info = env.step(action.numpy())
-    # Render the environment to visualize the agent's behavior
-    env.render()
-    print(terminated or truncated)
-    if terminated or truncated: 
-        # Terminated before max step
-        break
-
-env.close()
diff --git a/policy.pth b/policy.pth
deleted file mode 100644
index 970f5c065a6275f6d9a5038936d3e7bff9adc6ac..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 5512
zcmWIWW@cev;NW1u0Q?NX3<dc)naP#<DTyVCdIi}zZch9RQK+DSDLFYmCnq(zBr`v+
zn9IK?CABCu#U(SjgsYH2GpLYJBZ7;8fgvr~P_H~SGd-iEkSVz&zbH9FFTS)SGpCp<
zz9==RG&3h9z9coTIKL>q%!sRySwka&oq>U&xFo+QF+H`A1)_z^Ehj&*Bp4#dRmd8_
z$iTp0P{<a+%)r2qTu@rb?#sYbQpgeL&EVbO&C*uL>CNKJ*jC8ZnE`SNcS#{nP$6#w
z$gd#xBxNQR7xHB=YiMM!WIzlvEacB%_iiXD6absiS}5qv*jgwA(Jx$5C=ygC3iXc>
z%s*l%CKwfpXR!M+FqRZbfQ(@BZfGl%^lkt-LkePybV;F1P@yc;8Af_wXUL%#Vq7Sn
z!S2meQm6nnqqR^G<P0T<e&v!vm7qe^(%M3`2u=nDhWOmnl0;Ap6{>^1si6@8ipxUH
z+CnXGD3+xb6=&w>6>57k7K4K-E!nV8r?yZx11!SAz`#(Hnp0Y+S6iqL5dwvMN>P46
zerZXeL2aQSL=+UtX~{-~Mzw{;5D}2N;{3Fd+{B7PliEU4uxM#<l3su}J4gKS4ZdPb
z3=ANQJCzAA=)h80Qe{bMJ}CXTIccCLGt@+f#LLMnDalC%+1G0OSdxu_fdPbZ+sDgb
z53>&xX8HzhPS!+QhGZ{{X<6dAXQ!s?zIKU%eY1L1_h0l7+H322*tX(f_ujot^KG4O
zUfw5HSZin3T)6++oISQYlU?^Y?|W*unk&K1?yHa8lbo)7P2z(4b_s~>tH1uxwmV?v
zK9Bk<cGGU?+N~~LYNvd-#jaR&^<J%YJ@$`V8SFRh&e>OVsL4iKmdB3$45xj?!P$0i
z7QVGL*56|<%x7gUci(xRTnPLAualJb<vzZ>FYD(`n|NNoy)`Ee*fRdPVcjgUZ_mRE
zO?GRJH16BF^7}qr;Ul&`GeY;8#H`+HvGa=EggsO3%uTQEyTdubt}|)LJ}wPbyWUAz
zd!sHa-1o1z$L`H?rhS4tZtVKOC%XTlgV}yN6+s)%xx4p;y^picFP^aXHDlPmi)Yu`
zndo)w%Vj)ZcW>#lJ+t1uw>$Z$#oq1qGdltKH~UJhIrqzJEA3|t%-Gk?$ZoHkaKTQ(
zU&>ylSj0Ym)s}q=9Qf=OtrFP(L59mtM`!UquKc5WKVEvi@A<*<eFfgr?1VeM?wfAZ
zu-|F6$o{q|yZ2owR<^fkPPQ+4FJ<pl;%#@H|JA<uzk0Szg}&C_`xn@**r;YVVa9yB
zzo|3#>HJ^1H>9z}Zpo)*b}~yJ*&SZpZD*e&y07Ngt$nRp{Wj&-+4hLv-fH_?ecHbN
zO<VS@`@v+tnCHHo>G>nJGZ*sMn;x2FSN%uX{@B|?cISKf_eE=6+Lv(n|2|zsm;IXm
zCH7mHi`!MkH0|vZTD^})=H5PuJpX;pVf_1ZETs1Lsmtt*RGw%zQJrOf^rrOvS($SC
zuT4H|C*Y^JZ+4ift-;F8cKj3f+NQB9*=93U+WN&WvwI~z$@ZO0%f9lSlXkaaTK6@q
zS!SEF?w#FfaW?x|ed+tIH}~4j`K!6_{)Hmjm)>P|jX!4D)vF)g=j3_TI(Y6Y+qFvT
z_I)h3+gH3-WB-f|+w3eJthcM#a%k_fgBA8%wefZ$tJL<UMR)JJ{Q9yT_sQF~=^EQ?
zdk$Ibb5rHp@8|Y>@2g{s_Sf!i-S?^F+dg-bp1tzgEA4uhaqhPhklrUTuhnipTb*sg
zTNnGV-yZf`KMC5~Ff6cjOA)ZoT|3uaKTF1*Z`~!^<j?Q+-h487|I@pQ`&k%L_C9iW
zYpY@Q)NW>#^1e8+m$t0aAJ|zwxMufX`OrT0U#IuhvM=8A|JpUXj}sr*J-zh6?)#}_
zdzTwD?5|$Rx!?E4-F=a^68nCaY}vPDN9g{Ei*)u^&6C})(0F3+AKt!wf&~Tpc=oN|
zd*#(byRsD@?D$l=?994K_N_g7&-Q)j7Q4M1YI{!Fp5FVN|Do-kxrgk`R<!N2OF6SI
z@=oLat1A@t+qLhvEeP=6cl+z_eG5wV+pV@2wJTzMvTvnJzP+!2roAZ-lf9<ghJ6{|
zCfiKVZ?WtAV7hP9i%0v+LL&D^BzD=l>h9QQu|s(O&B(0%(N!Dw-7q}5FRISLuJ1$8
zes;^}Hfe8f?a{tBYag4i*nYVi9{XP1c(~8|jo5zWZ!h+3Y7ySw{HV`XJY(^`RX&?`
z_stEq+xMVs-@ea(_o+V=-5=d3yPvo9u&rA%|Gv4#;`{RL6z$%(E!lVO#jd?~M1<}4
z-&k&!aOmJZCf2=s`y%!2<+QZzMVX`QCKdDSpCV;rzq4w$ox{5od#B#ty)W3eb${TO
zpSFf^r|gQF-q?LoSZb#jI>GM#v|V=ATpRbv%$BoD`o&~d$iT2~O~hr}H-EMEJBIM=
z3uP7C*Hd)F?%9uw{jo;T`{ZNW_XUTB?+c!KWnW+y%YLg#TlV_A-n38X;{5%7i}vn&
zb1H0at-$hq_ZQ8vFO9lrdoc9$p7RrR_uAL&-J33<zi;k5_I;0@%(Z)Z=f}Q^Kqb2j
zW+l6_Mmf8QoEz=c#o6u8eh|0YeZbA`1Y@_o7VA?x=P%0p<SuB~t=z|I_vgU=eUV&t
zdyS$s?91$T+A~<(va{pQwXOQgU>~p|%l=>HMZ34w-1`DI1>61eRkDw*^|EK$)Ut1%
z_nf`FZwmGu6_~y6bM^$=ZOr`p^o0`b101*7IR>BF_aP?3PO`ykzx0Mu+a;k-_nx^}
zw2$eZ^nUL5>HD~1nC#Xj?cS%fZ}Yx2;*oo~CZ61PzxDFIw|eGwtRj2&?ViZD|8lsX
ztzSpv{?<!A`-Nh6+a8<oaqo=lj&>aT*X=tVxN=|R!q$BU#rW)e6?g2jvuD_sAimq~
zn}h1U>ef?s)!dWz=}QIL=_MuE$H!f>v$r^~ulZcPol&OI{^w^D_h+49wNCu?-&TUR
zfA5=HC+sYhxcA?i!Ms20&4qoQf7<ump0~|b;pi#5osIeSx`|ut9<DF2+mwEJ@5_?<
zeLt9g*#^r5?_DIpYX9c%V*7?Wf9!tVUT&xFUS}7$ykWmq>xz9_1Xt|i_|CoW*wGLB
zLhTv$pUz_1Yx>*L-k5WPoxPs6{kGj)`%D@e>@si7u@f}CYiFP8yf5R#j6Ep<C+*bL
zAKG;v<+YP|zQAVDC&vA`GJSSu*Bst+xAD&2`27>?#MHOhMZGSv+xC3@o`b24`vhyU
z>;lri*j~x_ZOgNB_C5=ZdA8qNH}B0XxoUUvf7U+kLL<A64>s>RE*xn0A>huw*S)`O
zTplXyPi1r5AI}}RKfzIae}(ouyL0ap?7zSIV|T#)h~2++pZBOX3D{3km9;<D%x`~m
z?tj}es!{tGPjlL9&*HK_JE3}SYs=g{Vek5FyEYZ?4HC}SXKZ}fE~v=WeoFR3yRV8m
z`(#%=w9^z{X`drpX~%b&eP6kNj-B?=`TM$<*zImOcJ8%vcd}Qg^|POGaIPJw<<WIP
zJ(B^{s^nl`z}NC%MQM2$x;Zg2L><{S(@x#jYTpA{4f|#H)b@Kx3ha-$y4OagDZy_0
ziCmjE$vpN!bM5TBL}m6_t!dl4q<G)nqbKtADGGMk^<^KkQwq6dYrj={zu$!*yOO{C
z`!bhbvAyB2(C&Jn@xIojnRXvizSzVqxv(!jf_MKf{TKUYHMj1|yvwy;OYDU0S$4gB
zQJtUm{`kYW-`q@h-;^V(?7}nU_isCKX5aFh>V4L=;``M*@9+C}yJ6qP15@^Wa6N1H
zK6d`TH|@>4|FhrOce#sk|C`28ySa7Vc5-@L_P2Mf-Ix79+fK#S!A>M*-aaR8BfEx#
zi*_3o0{42k7VSO7DZIbrnWCLB|FeD19*Wt<>h|wrt=zKj{@1m3m$X{;&2rshcfH(k
zUn>8CeMZtR?Tq?d_kA=K*lWC~&1Ty~NxONEd+ZAP73~(q?%Q``?{8anb6Go$qgr-u
zogH@17B$%I-XyzEaBAPaO&+dxv*Ww%LKWoqe>RM_Yg%Gr_m%mY-R*0tc9xpc_OAUq
zYhRbPxLw7)Py4<XP2Cq$HG6Mx(9V6V5yks*GoI`_mVdzZ{BA9~bvhgNB_-{!d$lOp
zZu#p~yZ7A>v@=?B%XZNWYrAOuV|EVWZ2R9zuH4(0yx*2N{KUSs1q}Oz&x-B$Uc7Z*
zN|pD%ox4@-&-HZKhE*N26+hN)6S3~uKI4}OcEt_Nc3#PB`&(StZ9(bB>Dvh=1||ju
z5GFqT7=hD|AiKSt+Iz*louyHBi|6s#wOB3Pqb0Dy_S8Aqeb2OAY{i8q>?_qVu>0uG
zYP)FF65GU3N!tY#QG5EH{oTWQVX|G*y?wT6=0?^k3tIL$_Dr&y|MT14-aU+Vf4yvM
zT|6(^{XFZp@3ipyy*F>L?z{H)pl#2-e%sGmQ}(3aZ`)`0Z0o*z|GwC1?(x|hXnn!9
zCFYIYf&yjRM{NOnSq;-{!^*$yJrw0?>)Z6qcIjVFyR4+gd(EruZ2CL0ZMImhwY@5T
zb?+ASh<#^5xc2{btFz0C548OiD86sqExEmoj~3W&wVq>lAXLGQA#>|qK0%|sEaxBE
z2K(){TN`L-dqBW?-v+sTduOhzx2=7kZ@2h`fSsC@tL@3=C3eTFJ9fuiUuf65)PJv`
z$jaS~*E#l;T0Y&&%BgO5nmKsiyZx(dI2;#S7gwybxp9xpj&YiRz0nS~Jq67^w)vYE
z*$63~-^2eydY@s7qfJJY!@kH5Q};O^i{9&K|H9U6-}ya%RU+(G_T8~Pu=a=Tkzfwn
z+zs38_J?xY>e}zJoi~li?ttdZy*D}o?4tf#?-fg~va```+t+yN?VefdllCQ^{lAA(
zvcPWDV!3@%J3DOIz6R};d2PNYHEY`5OPi9c<yY$19C=V?J6%N4?lALQyGI&M`?}Me
z_g#5^Xs^XHqdkv8`u9GL)!kP%-EPm0Qucj}pPBcyTZQc9TyMOOSKedK>pQ>qeqhnC
zxqrIY?%uXu>vpe&c0t`*`|K^PY^&}z?=>@<Vk^6u*Vf_cGMi^Fc=jE9+iyGN>zlp(
zJxzA6mUiruc4)G@7-hBR7{B~J*T#waS_&HMQal#e{0rV_6T58kUiRmw_N-qoWV`FM
zrrr0i6YYdHuh~~`dwjQ2RrEd`%f<U5<!0?IWpmi8;Fh^}--PwHJ_~2tXf)PZA1;a5
zmvi!`t<jb&8+Yz>+q-vv?M=L{XuI=xm>pZ#Cp)1jIs1Yao!rCxt;%+Gz)IWmZ*T4m
z-qN>EarfeVyRDAfdZp&r9&^aBjp%r?@5v3@Jv^n}cARdrY}=(T?yX*EzR%Nj!oKWd
zCi_l*zqGF`RmxWC$Hjg7T<+{S?hw3Z3NxeaFS!f0a|E6C?Nej7%WmSgnJGGFZ%KmC
zzV2HU`;MLG+JE)djok{>6ZXEc*4t|<XJ^-ZQPb}KJgI$wd$!qiPv2{|?k}I+mdl6s
zZhE@Tw*R5-zS+G-dnYDcvQb*$xGy`cVqex94ci4_3j4N*o!_(5P1Ww>>Dqni;tqTH
z-h}LHIkw5p@2!Jv#LqT6tvieN7(dv$_r(qmJHNOEd;L<QY>hYZ?lbDi+&ejV#$Jzr
z=X+(Mm)i)LzT0cV-?lHp^o8xI`Sa}}l~?U~czM>|dxt03#)QqVJN9Yv-un4p!R3cE
z=fo}sP&b~1l={aQTz(kG`|siVYqb|t4*fk@7YnV%@YG{`3^uTO3_7yl=43&37uuK$
zWb?&&>7md95Rc8W48|~<^}wC?;?$zd#GK5kM9|<<d}fN9lRl{qF)%eVGd4A}G&L|W
zF|;tWG%>U^H#9deFb1)WEQ}4!EG^6p%#A^=oFDDBM1z5W0fYm*89@$)kDMWocJPD7
zUr;dGI2yWU<e>`z6pd+23@|UlheyzLBR4;IQFLEJ=mv+E0lHS?ng>;D05{xXuvSBK
zt;mTERqF&^RINtnT9MNks@4mFs9KHDwIU}JJ`}%72*I^NlO?)d<mi({(K`vD7kk1A
z@MdGvfohRs)`hYdz>Pi-2b60-^gGaqBm)D3g)oQ*%8hKGF){`<Q3xL-9pKFh8lqz0
MVBlZ?sfVZq09epwKL7v#

diff --git a/reinforce_cartpole.pth b/reinforce_cartpole.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8f0468d676b0a0ab9441b9269ec62d27ca79ea83
GIT binary patch
literal 5800
zcmWIWW@cev;NW1u0MZQX3`MD#d1?7Y$*J+liA5y^`8lcjDTyVCdIi}zZcfY$QBhG1
zOv%alIXS7xC7Jno#a#YHDXB%NDK44GC0vCJnn8t(8WCI!3=C<>hI-|xnduoNg-ppM
z`9;YYdhw+tnK{K=@kOagrI|S?@g=Ew#rZ|?Wky_u%o-XI><kPH#U=SgiRr0@ED$YR
zZaMjhCBYCuu0qxbMg|53gF?0lW(Ed^<bu*dc3%djl0uF^ZwBuMZ<e+~PHz@(#<oJP
z&J2)KxJwFof(m&fKz;?eCn+<rxR5V{Swka(B?DrZVIhA8yLUrLp#a#7)<QvV#@0e1
zh<@RcLXn_CQK)~6VEz$9F~O)%JcHerfw81e0%Qb}cSBpDq;~_z8B!2qq)Q5Af(m7!
z&M?vgJ3|h|5aUAm40dm(l0pTr8LfqiAZI8+^edMXsst6Pmev-kMQ}1OFvRDkmL!5=
zrBEI0O%06*P+S&j))s1kL$NHis5mn}uTa~Yu^1dwX~~9#I<<wm8DJ3>1_p+r)SS{n
zz1l*3h!803Q;PBn@=Hq!4QdMwA)=sAPD?f_G^#B$hKPXF73Zgw<R(@Wn$#AWf<;S<
zlk@_-**W5mZ}1gkVqgGa+^I~4K>=GTOR6kM%?G6=HzyfTGK<0maxzOwa#BI|wc0+G
zWMg1p0AbwrNirB?vkw%0`UY-J`gkqF;s~Dq8|)H)OYOV9w$J{vC%^px;rhK7U(dDe
z?E7J>!|`qJ#S_!*t_e5Uy}9DO@2H)`{v^Z8cKm(zd+#stw7bH)$S&v2PP-<temmK#
zmisS%GTXOv`CZ!shWl-+tQGgS^GDeIXsz1Ey^hoV!=`w<2b$OIoEL`LRqC1U`#hiB
zZff-@JKlPh{oAY$@B99L(LTKgy!JCa-|zGCp1UuuRK{+@z2kOU=SkaXR!i7f&C}R_
zf=g-d3G=3X9kRmq51c31CHA`7MaVepbNuXMmp(aapCWh6KJn|V`(iF{*vBedy?@jC
z?R!`MP~9V4b->PBlY770U61`{6&(AgnsV&#h~2i2;oho!HK}~|0qZ8(MV#%nJJ0pa
z?yhI}-ljeJ`?dwP?7x=Hu>aWFH+zF*zwG-`ICXzlxy$~)*Ou(N9o4ySUX=0vlPM1S
zWp8Wm`*SMIZr#P^ePL?1_G!P_Xvd+|w6{}_)7~=6)=qT2_&zQR0Xyq_Tl;f)YwU!5
z7VTRRv}x~yyHocsJTJJ9%l-JC*!bRkD`!R8J!pDucm52YopnQ#t(Um;zEX}DyS_bn
z`|i9?w!74{&~C=ldfR`8=GZ!Hm$18%xn$ot^Zb3`^)K!2n|JM-F!|iRCYjs&K8LyO
zG4PwUFQUTFezLr|{ngh!wniUL+c7LMu{&0M$8MX-d)wA<WjjI7Zu>VJ+V&r(H`{4v
zZ?WzB(7W&PLX-X9x&`-5zZkIZ*2>BD$EVrZ&x~T<cYObXeL}nL@9C;vyszHs$nG1>
zEBD=tXW#$mj_Q6>&!u}FubgPtEgxhrp66&E(#ULkps#DMrTkpGm1(SY*1T4B`fOr$
z4N|f8%a=*me@H9aJ56lu-qj}B_NS!S?AbT@?YnuU->y6I`regxX?D}9>g^Wjvh35i
za@f{r+6z1HH1mBs*e6(J{%YQD^u=S}xf1dHZW1Ck{`YR%W&Ad>+ab`fZ_STec86Kt
z*)91IXVZMrZeOHh)jn&Txc$24r1qcLYqHnMxWO*?k*R&+5()dp#2I#XZ++XByR~-z
z!D$ZrKc2JRTVcy=e?<DdU0!O!zM@1&TZTo+`!a0g_itF2y??Xi9J@^u%k1WTc)s`L
zjFWa8=NtEJ2%BVQ5#VjVuc&$N;qNwfZLE9j7){>TE;nS~Z<JeP_imc*KEqW9_HADG
zZJ)Wy)_sTT1nf#fUfVs|D{WurCuH*~$a7CZ;N1N)953!;p0~?x{n<Ts7Q)-?i`g&R
zy_nx&)4JT;?tM$w{)c7n_KA6w*rk6+*c0<9z`p)EzrAhNZ9Bi5W&0A=owVEf+s01o
z;xW5V?|kg)M8fTsRejnwKg7(=t?tr3)9wHFsc-Dqr*ezIzW2}(TW517`&IJO?G$b}
z?R&Fm`932JDf_)mx9ncM-MG7JSFr7q(wqCfEH2%*NZrHs;FE}bn!YOgrhW~vGx|Bx
zZncMj?bep{_96ca?e{wb?dz^yw6A_a-hQ^J0{c5XTKDZgu*=Ts;tD&@mW+LlzNh!)
zo|wDuOzccMvu#`Vp0w()JvFC!-_`Ih`%(kA_b+=`XvYvPy8rAI?fq8@9ql%2Htw6>
z@O#hGF!%kM@tXU2RzKYLx#G#b%ZqdOyPo{IukP=%eeaI0w5u%aw|_hJn4S1}6}ts9
z*4n*#XtBR<+0lI((ml34#n$_DC#<wR(X(^k;*}fiW|mIb7kFa*zV_AY?GCvm+7>%K
z+xIb}-~OYsu)W`WKiiC`_I>t|ANE!;FzpxZ|Gux+;Jn?J!{+<yzr5HzwfpwoM=LD$
zNlxIkHw$&JV=rvoXJN!+`zE?#|G7A^{Zj3J?3zCW?t7N}c<=5PX}jZ$Sof71-ru`q
z+l_tf4=wha&e*%}$IU1E9*gwtKi=-R-`6MGHtqU;ThTkR_CJFK?7wX5x9h9BvTvJT
z%f5uA^81sfytC1Mch{!I`MDj_v`Ks0t~{{~TPm^7`dq6$i?f)0{N`J`kFM{uJrq{C
zcZv8+I|K1~c9|~1`#&73+`ph|(LNoWdb|6!%WXG5=d|A>(PWp=QNDNk_HDK^1=Q?r
zUs2fi|LY7pao@LgehTyVSFgUhPjBJ^yS9RZcEVRT?lZe5wtrpQpMA4eAF?yJ&A$Kq
zwn^4MVpiH+xZh|uv-Q!w8$U1YJI-2QYx`2b-goC>`-30<+RgiP!;YsU($?XmrG1eH
zll=`nFT3eh4g2`@j@iji`DVu`AZqW=9&fwQ#ovD4c}e>=<{A6)PqgouxYyY3)^Ry|
z{(hPLnt_Y=CU17zx5v$8-=u;c`{b_2+wHwAwm;9sW#3P0-hDg8*6lsJeA?bdrrCA{
z?7#NjZr^1iy;#nEj_dM$Ifs4r?an{3ue<2FZ6VK#eWrcy_OVs5?RT{9viqvQVt-zA
zqV07>$9-|kr}rJW9JK$>%O`dUD;W1nFL_}%m3_Hg;mMADCRft!n&SNTF_-Y~>s2wg
zo3C)qF1%mcexKsRy?#!w_NunbvJ+JB-v8v%g?-V?_O=WhxAxwtJ+)7VCwgCuaqPay
z)hc%KEO+;DvU}_c-}ZWsO4#FlDI90^=|5!J@44~lz84>V?VEJ${615=ySsT9R@kX`
zvsuRr#_hk-VX$BC$5C5niwm~V)6ebuuC#OCslU>Dc6C3nohES9PBi4)zS~WnHX8)Q
z_ddUIaR2Tw<Nb}h18m<+YO`xFwX-*Gva&zi=)Es5PsT2hQP}>LXtaIoQ!g9-GspMt
zcow>U!;<N?Mdw@g#kTLZTfP5)-68uqwyS^c+SjoD<6i4!%=VL(#n@e}G2U0GnqgDN
z{>sj<N_iis<q==~MvDQ|s^nl`z}NDy#M$yNbaOH#wdDb6j6AMgVRyH^#xC|<mrc#B
zQ~SD_TkLLnf7r(y!oS~GdB;ABnJ@PyZQQ$8i0{zewCMGAoDyI6cF0Y)i+0~?yCH7N
zo)_QN+8qv>Y;$qZO*`Ii7wu-b&Dgt0?xD5Jv1@h$3w8IpF6y*R{KdI{L-pLfYd+4~
z`+aMloz{t3yH#}?_Z8j<w+pkGXZP&~tNkuXPrI#`)$H5_O6;t@?cUdr++!DQ@zVD5
z>acx+ie7tT*!u0RxqY(pQ9o%n|4*x3`KA?iP4$260{30qZE6}}$6MaNZ;^?NoztwD
zc3u-KY!{g}+jOqrw>!Ujt6k9aGCS^uI=lJvw(aZL&A#8X@$8<eo<FwhH0|uF6qIc5
z#7?twZQf(2_9E4;wEgYAMO&}#3x9pY?w~KDy>G0IUBpVBy%meE*?qjwWB1-}@xE7g
zkL}yO;kO-o^Lo4Y6K(fef6m)iusF{yK4$K|<frRxd%alg#WEFbqb^F?Tv-smZ|<q<
zcDI$6?b{Y|%yw_*>V2ki75nU7-naX@w%@Mh*Q<RGjWYMG^?hc?qOxPJ@}{}AC$+xW
zdA11I`Bl!a)4LzFFI45zzSLfxeR^LT_I2NQvbXl0f*t1*$$igwMD_>&*|(4X<3zi}
zOMUzLj>_8k`p&dF#HnVdzp%&7$F6m+`jO|m?}_r;|9G{=E+X%b-K*a$wwLT@?wcy|
z*;X^A4V?b+Zkf+yU}9hZVe-<S5!vZ4U{Tz@cX?fVtLII!{r{0`k1}J>uGY^EdmR3S
z*ebaFw%O;}Y8U6AX1BDAYwy;J4z?w~>uv9+9k96~r)jsiC*1Dxt(klGN&T^P>dCfC
zsVcWSndxnFHTd72L;H5_`NWWB$I^LnkMooy+rn%%+t>Z~Y~J2TvpXVKzVD#l6x-jn
zH}@%ivEK9L%(i{%`nI<1`it%Q|Euhqd1Kbzr;ldr4c?c!&n)JTdG?krn@-J{w&7BJ
zd(YlEx0iVZo81i8?tMS3XWMP(`M*bmX@OnK>iM>77qZ%V#pT;g5}v$gTF$z?+Lsya
zCTTYAtGcUZ8+L5|zCM+CdwZtu-YdFw?w$pI{P*23ir(9!uy@Z%$r{^V6L0KWz;}I*
z=%LfQ7ap2xYjOYRo)0^8Z1=3Sv70a0Vt2FP_g;}ty|xN@M{GhG#P*&op0-cxOP$S(
zj~O;IthjBG_TIGR`>eB1XraZPO!*qS({FnBo?5iYmh)WJUJU_{eNmTo*otqKv$HPZ
z-Y2v8g^l|1bi3o)26hYe`fWZ3`r8G$$L<s6RN2R^czh2hf8)M=YGV7^cwFrc7H+a_
zZNFl@V`sHZ@Lo1M|Koytw>Q1DP2Co`?|=S`z1uRI_p((kwOuF7v{&(HqKy`Vt=$Qo
zemlcElkDc#XzXoGw%x1!c9orK=COVI(zoqPnp0-`Z%x+T<9FEihEH3yM}pVX?se#6
z8;%)&_Pl#qyLZO&>wEeCnb{VLe6_h5vDntMPJ7>ih{t>Lrxe*-+MQur{>9E_BKN|*
z%LO|3JzX2LFS6Rpj{B+6zE5-d_nA7_?%O{7jLjpBNw)6EqIR3#RqV^s5VB>hD6(y<
z(c5$SdBa{MuIGDeYk2JJ&u`lET4K4Kupj%rqt`y!<#K4+UNUOlx6t0%F6wWHP2kL^
zJ=J>@?Fv1Q+3w(X+`Z%k<334VNxKZzL$>SQOYf6v?6z~}*uC!(^9F0T`C|KG>l*eb
zNzJsC%<kSNl*qDgskfQk!55wTG^QupMSNnmS=0BzF1K;EZAcHV?aLYRd%Yqy+cvTZ
z*or-u-Rq<7YyG!jt=*-tn|s%5^6u?gJb$md<(s`?a##1hVT;_y&z-UFL9D;sVhi`Z
z3qQQGd2(9W<_dqe-L812J+6~2ZGP|Ewr5JA(Y{i@6}Imcw(YTNoU`xwqboN2D~$J9
z=ezDra8tAE+atNRa*D^k)3ey@)Lf_AesZ6-@9dZBd*yUx_Q~E!+^5iDV^<bD*;e$O
z@?O7Pd3FM|C#_e;-QDYx$+qwIZ==0B0StSiS1j7IdX0hIadobJf8JEvd2lbc=2Ba2
zyIt91AH#8*y;H3x?>VA5W8cM8n|)uM*6(v<YS=S_m2t0+_lA8NC2!i*3qG(j&(Pfa
zVR?{^M_;?m^TY&OQ2jUa`Ac>N&@ciEx%Hnh+2xNLi_zXa@5^jJ?uee3RRgVt@zldo
z3<lWhVdyA{o0ASn6|pfF$mWal(nFyIFdmyV8Pu`atOp)^C{8WPOw7rwN(2q0#b>6t
zIVlt45Cc;)Gh<UjOH%_A6GIC_OA|v&b3-#r19J;g3v)|LQ%f@ob3+Sr6Ob!SqwH!l
z7#J8pIKZ0`<Y4&7AM$98G-&(?1>+e9L^lX|7)A!gkTfRHKoiD*B7CR^-6Z5@l_ZKu
z*RY!e4o3rYV~}e{RAU0Tu{sEBj3K%)$cY@)m<znPjWI$u204wQ8WSLd+ZbbXV~`V^
z6iV3Kz-|mQd83<yoFp_+%n1_4>KyEeHNcyVO$Vx7j#(GVVgNUaK^#!71krxrG{?a3
lLm0#Zl@@HEF+>J8F$f<tpc~-L3L3&=;9%fj0I7$l1ppz8<-7m@

literal 0
HcmV?d00001

diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py
index 50be37c..dc01fd7 100644
--- a/reinforce_cartpole.py
+++ b/reinforce_cartpole.py
@@ -1,13 +1,16 @@
 import gymnasium as gym
 import torch
 import numpy as np
+import matplotlib.pyplot as plt
+
+DROPOUT_RATE = 0.5
 
 class Policy(torch.nn.Module):
     def __init__(self, input_size=4, output_size=2):
         super(Policy, self).__init__()
         self.fc1 = torch.nn.Linear(input_size, 128)
         self.relu = torch.nn.ReLU()
-        self.dropout = torch.nn.Dropout(0.2)
+        self.dropout = torch.nn.Dropout(DROPOUT_RATE)
         self.fc2 = torch.nn.Linear(128, output_size)
         self.softmax = torch.nn.Softmax(dim=0)
     
@@ -38,60 +41,39 @@ def main():
 
     max_steps = env.spec.max_episode_steps
 
-    for _ in range(epochs):
-        print(_)
-        # Reset the environment
+
+    for ep in range(epochs):
+        print(ep)
         observation = env.reset()[0]
-        # Reset buffer
         # rewards = torch.zeros(max_steps)
         # log_probs = torch.zeros(max_steps)
-        rewards = []
-        log_probs = []
+        rewards = torch.zeros(max_steps)
+        log_probs = torch.zeros(max_steps)
         for step in range(max_steps):
-            # Select a random action from the action space
-            #print(observation)
+
             action_probs = policy(torch.from_numpy(observation).float())
 
-            # Sample an action from the action probabilities
             action = torch.distributions.Categorical(action_probs).sample()
-            #print("Action")
-            #print(action)
-            # Apply the action to the environment
             observation, reward, terminated, truncated, info = env.step(action.numpy())
-            #print(observation)
-            # env.render()
-            # does this come before adding to the rewards or after
-            
-            # rewards[step] = reward
-            # log_probs[step] = torch.log(action_probs[action])
-            rewards.append(torch.tensor(reward))
-            log_probs.append(torch.log(action_probs[action]))
+            rewards[step] = reward
+            log_probs[step] = torch.log(action_probs[action])
 
             if terminated or truncated:
                 break
-
-        # apply gamma
-        # transform rewards and log_probs into tensors
-        rewards = torch.stack(rewards)
-        log_probs = torch.stack(log_probs)
-        rewards_length = len(rewards)
-        rewards_tensor = torch.zeros(rewards_length, rewards_length)
-        for i in range(rewards_length):
-            for j in range(rewards_length-i):
-                rewards_tensor[i,j] = rewards[i+j]
-        #print(rewards_tensor)
-        for i in range(rewards_length):
-            for j in range(rewards_length):
-                rewards_tensor[i,j] = rewards_tensor[i,j] * np.pow(gamma,j)
-        #print(rewards_tensor)
-        normalized_rewards = torch.sum(rewards_tensor, dim=1) 
-        #print(normalized_rewards)
-        normalized_rewards = normalized_rewards- torch.mean(normalized_rewards)
-        normalized_rewards /= torch.std(normalized_rewards)
         
-
-        loss = -torch.sum(log_probs * normalized_rewards)
-        total_reward.append(sum(rewards))
+        # compute the discounted returns in reverse: G_t = r_t + gamma * G_{t+1}
+        R = 0
+        returns = []
+        for r in reversed(rewards[:step+1]):
+            R = r + gamma * R
+            returns.insert(0, R)
+        returns_tensor = torch.tensor(returns)
+        eps = 1e-10
+        # normalize the returns to reduce gradient variance; eps guards against division by zero
+        normalized_rewards = (returns_tensor - returns_tensor.mean()) / (returns_tensor.std() + eps)
+        log_probs_truncated = log_probs[:step+1]
+        loss = torch.sum(-log_probs_truncated * normalized_rewards)
+        total_reward.append(rewards[:step+1].sum().item())
         # optimize
         optimizer.zero_grad()
         loss.backward()
@@ -101,19 +83,18 @@ def main():
         #env.render()
 
     # save the model weights
-    torch.save(policy.state_dict(), "policy.pth")
-
+    torch.save(policy.state_dict(), "reinforce_cartpole.pth")
 
-    print(total_reward)
-    print(total_loss)
     env.close()
 
     # plot the rewards and the loss side by side
-    import matplotlib.pyplot as plt
     fig, ax = plt.subplots(1,2)
     ax[0].plot(total_reward)
     ax[1].plot(total_loss)
-    plt.show()
+    ax[0].set_title("Rewards")
+    ax[1].set_title("Loss")
+    plt.savefig(f"reinforce_cartpole_dr_{DROPOUT_RATE}.png")
+
 
 
 
-- 
GitLab