From 36eabce609ea2475cbae87f93a2df8bb9932013d Mon Sep 17 00:00:00 2001
From: number_cruncher <lennart.oestreich@stud.tu-darmstadt.de>
Date: Fri, 14 Mar 2025 15:15:07 +0100
Subject: [PATCH] init

---
 a2c_sb3_cartpole.py            |  55 ++++++++++++++
 evaluate_reinforce_cartpole.py |  30 ++++++++
 policy.pth                     | Bin 0 -> 5512 bytes
 reinforce_cartpole.py          | 115 +++++++++++++++++++++++++++++
 4 files changed, 200 insertions(+)
 create mode 100644 a2c_sb3_cartpole.py
 create mode 100644 evaluate_reinforce_cartpole.py
 create mode 100644 policy.pth
 create mode 100644 reinforce_cartpole.py

diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
new file mode 100644
index 0000000..ee1087d
--- /dev/null
+++ b/a2c_sb3_cartpole.py
@@ -0,0 +1,55 @@
+import gymnasium as gym
+
+from stable_baselines3 import A2C
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import package_to_hub
+
+
+# Run configuration, following the wandb SB3 integration docs
+config = {
+    "policy_type": "MlpPolicy",
+    "total_timesteps": 25000,
+    "env_name": "CartPole-v1",
+}
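+# sync_tensorboard forwards SB3's tensorboard metrics to the wandb run (see the wandb SB3 integration docs)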
+run = wandb.init(
+    project="sb3",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+
+env = gym.make("CartPole-v1", render_mode="rgb_array")
+
+model = A2C("MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{run.id}")
+#model = A2C("MlpPolicy", env, )
+model.learn(total_timesteps=10_000, callback=WandbCallback(gradient_save_freq=100,model_save_path=f"models/{run.id}",verbose=2,),)
+#model.learn(total_timesteps=10_000)
+vec_env = model.get_env()
+obs = vec_env.reset()
+for i in range(1000):
+    action, _state = model.predict(obs, deterministic=True)
+    obs, reward, done, info = vec_env.step(action)
+    vec_env.render("human")
+    # VecEnv resets automatically
+    # if done:
+    #   obs = vec_env.reset()
+
+run.finish()
+
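+# Upload the trained model to the Hugging Face Hub.
+# Note: depending on the huggingface_sb3 version, eval_env may need to be a vectorized env (e.g. DummyVecEnv).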
+package_to_hub(model=model, 
+               model_name="CartPole-v1",
+               model_architecture="A2C",
+               env_id="CartPole-v1",
+               eval_env=env,
+               repo_id="lennartoe/Cartpole-v1",
+               commit_message="First commit")
\ No newline at end of file
diff --git a/evaluate_reinforce_cartpole.py b/evaluate_reinforce_cartpole.py
new file mode 100644
index 0000000..01479f7
--- /dev/null
+++ b/evaluate_reinforce_cartpole.py
@@ -0,0 +1,30 @@
+import gymnasium as gym
+import torch
+from reinforce_cartpole import Policy
+# Create the environment
+env = gym.make("CartPole-v1", render_mode="human")
+
+# Reset the environment and get the initial observation
+observation = env.reset()[0]
+
+policy = Policy()
+# load learned policy
+policy.load_state_dict(torch.load('policy.pth', weights_only=True))
+policy.eval()  # switch off dropout for evaluation
+
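+# Roll out one episode (at most 200 steps), sampling actions from the learned policy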
+for _ in range(200):
+    # sample an action from the policy's output distribution
+    action_probs = policy(torch.from_numpy(observation).float())
+    action = torch.distributions.Categorical(action_probs).sample()
+    # Apply the action to the environment
+    # Returns next observation, reward, done signal (indicating
+    # if the episode has ended), and an additional info dictionary
+    observation, reward, terminated, truncated, info = env.step(action.numpy())
+    # Render the environment to visualize the agent's behavior
+    env.render()
+    if terminated or truncated:
+        # the episode ended before the step cap
+        break
+
+env.close()
diff --git a/policy.pth b/policy.pth
new file mode 100644
index 0000000000000000000000000000000000000000..970f5c065a6275f6d9a5038936d3e7bff9adc6ac
GIT binary patch
literal 5512
zcmWIWW@cev;NW1u0Q?NX3<dc)naP#<DTyVCdIi}zZch9RQK+DSDLFYmCnq(zBr`v+
zn9IK?CABCu#U(SjgsYH2GpLYJBZ7;8fgvr~P_H~SGd-iEkSVz&zbH9FFTS)SGpCp<
zz9==RG&3h9z9coTIKL>q%!sRySwka&oq>U&xFo+QF+H`A1)_z^Ehj&*Bp4#dRmd8_
z$iTp0P{<a+%)r2qTu@rb?#sYbQpgeL&EVbO&C*uL>CNKJ*jC8ZnE`SNcS#{nP$6#w
z$gd#xBxNQR7xHB=YiMM!WIzlvEacB%_iiXD6absiS}5qv*jgwA(Jx$5C=ygC3iXc>
z%s*l%CKwfpXR!M+FqRZbfQ(@BZfGl%^lkt-LkePybV;F1P@yc;8Af_wXUL%#Vq7Sn
z!S2meQm6nnqqR^G<P0T<e&v!vm7qe^(%M3`2u=nDhWOmnl0;Ap6{>^1si6@8ipxUH
z+CnXGD3+xb6=&w>6>57k7K4K-E!nV8r?yZx11!SAz`#(Hnp0Y+S6iqL5dwvMN>P46
zerZXeL2aQSL=+UtX~{-~Mzw{;5D}2N;{3Fd+{B7PliEU4uxM#<l3su}J4gKS4ZdPb
z3=ANQJCzAA=)h80Qe{bMJ}CXTIccCLGt@+f#LLMnDalC%+1G0OSdxu_fdPbZ+sDgb
z53>&xX8HzhPS!+QhGZ{{X<6dAXQ!s?zIKU%eY1L1_h0l7+H322*tX(f_ujot^KG4O
zUfw5HSZin3T)6++oISQYlU?^Y?|W*unk&K1?yHa8lbo)7P2z(4b_s~>tH1uxwmV?v
zK9Bk<cGGU?+N~~LYNvd-#jaR&^<J%YJ@$`V8SFRh&e>OVsL4iKmdB3$45xj?!P$0i
z7QVGL*56|<%x7gUci(xRTnPLAualJb<vzZ>FYD(`n|NNoy)`Ee*fRdPVcjgUZ_mRE
zO?GRJH16BF^7}qr;Ul&`GeY;8#H`+HvGa=EggsO3%uTQEyTdubt}|)LJ}wPbyWUAz
zd!sHa-1o1z$L`H?rhS4tZtVKOC%XTlgV}yN6+s)%xx4p;y^picFP^aXHDlPmi)Yu`
zndo)w%Vj)ZcW>#lJ+t1uw>$Z$#oq1qGdltKH~UJhIrqzJEA3|t%-Gk?$ZoHkaKTQ(
zU&>ylSj0Ym)s}q=9Qf=OtrFP(L59mtM`!UquKc5WKVEvi@A<*<eFfgr?1VeM?wfAZ
zu-|F6$o{q|yZ2owR<^fkPPQ+4FJ<pl;%#@H|JA<uzk0Szg}&C_`xn@**r;YVVa9yB
zzo|3#>HJ^1H>9z}Zpo)*b}~yJ*&SZpZD*e&y07Ngt$nRp{Wj&-+4hLv-fH_?ecHbN
zO<VS@`@v+tnCHHo>G>nJGZ*sMn;x2FSN%uX{@B|?cISKf_eE=6+Lv(n|2|zsm;IXm
zCH7mHi`!MkH0|vZTD^})=H5PuJpX;pVf_1ZETs1Lsmtt*RGw%zQJrOf^rrOvS($SC
zuT4H|C*Y^JZ+4ift-;F8cKj3f+NQB9*=93U+WN&WvwI~z$@ZO0%f9lSlXkaaTK6@q
zS!SEF?w#FfaW?x|ed+tIH}~4j`K!6_{)Hmjm)>P|jX!4D)vF)g=j3_TI(Y6Y+qFvT
z_I)h3+gH3-WB-f|+w3eJthcM#a%k_fgBA8%wefZ$tJL<UMR)JJ{Q9yT_sQF~=^EQ?
zdk$Ibb5rHp@8|Y>@2g{s_Sf!i-S?^F+dg-bp1tzgEA4uhaqhPhklrUTuhnipTb*sg
zTNnGV-yZf`KMC5~Ff6cjOA)ZoT|3uaKTF1*Z`~!^<j?Q+-h487|I@pQ`&k%L_C9iW
zYpY@Q)NW>#^1e8+m$t0aAJ|zwxMufX`OrT0U#IuhvM=8A|JpUXj}sr*J-zh6?)#}_
zdzTwD?5|$Rx!?E4-F=a^68nCaY}vPDN9g{Ei*)u^&6C})(0F3+AKt!wf&~Tpc=oN|
zd*#(byRsD@?D$l=?994K_N_g7&-Q)j7Q4M1YI{!Fp5FVN|Do-kxrgk`R<!N2OF6SI
z@=oLat1A@t+qLhvEeP=6cl+z_eG5wV+pV@2wJTzMvTvnJzP+!2roAZ-lf9<ghJ6{|
zCfiKVZ?WtAV7hP9i%0v+LL&D^BzD=l>h9QQu|s(O&B(0%(N!Dw-7q}5FRISLuJ1$8
zes;^}Hfe8f?a{tBYag4i*nYVi9{XP1c(~8|jo5zWZ!h+3Y7ySw{HV`XJY(^`RX&?`
z_stEq+xMVs-@ea(_o+V=-5=d3yPvo9u&rA%|Gv4#;`{RL6z$%(E!lVO#jd?~M1<}4
z-&k&!aOmJZCf2=s`y%!2<+QZzMVX`QCKdDSpCV;rzq4w$ox{5od#B#ty)W3eb${TO
zpSFf^r|gQF-q?LoSZb#jI>GM#v|V=ATpRbv%$BoD`o&~d$iT2~O~hr}H-EMEJBIM=
z3uP7C*Hd)F?%9uw{jo;T`{ZNW_XUTB?+c!KWnW+y%YLg#TlV_A-n38X;{5%7i}vn&
zb1H0at-$hq_ZQ8vFO9lrdoc9$p7RrR_uAL&-J33<zi;k5_I;0@%(Z)Z=f}Q^Kqb2j
zW+l6_Mmf8QoEz=c#o6u8eh|0YeZbA`1Y@_o7VA?x=P%0p<SuB~t=z|I_vgU=eUV&t
zdyS$s?91$T+A~<(va{pQwXOQgU>~p|%l=>HMZ34w-1`DI1>61eRkDw*^|EK$)Ut1%
z_nf`FZwmGu6_~y6bM^$=ZOr`p^o0`b101*7IR>BF_aP?3PO`ykzx0Mu+a;k-_nx^}
zw2$eZ^nUL5>HD~1nC#Xj?cS%fZ}Yx2;*oo~CZ61PzxDFIw|eGwtRj2&?ViZD|8lsX
ztzSpv{?<!A`-Nh6+a8<oaqo=lj&>aT*X=tVxN=|R!q$BU#rW)e6?g2jvuD_sAimq~
zn}h1U>ef?s)!dWz=}QIL=_MuE$H!f>v$r^~ulZcPol&OI{^w^D_h+49wNCu?-&TUR
zfA5=HC+sYhxcA?i!Ms20&4qoQf7<ump0~|b;pi#5osIeSx`|ut9<DF2+mwEJ@5_?<
zeLt9g*#^r5?_DIpYX9c%V*7?Wf9!tVUT&xFUS}7$ykWmq>xz9_1Xt|i_|CoW*wGLB
zLhTv$pUz_1Yx>*L-k5WPoxPs6{kGj)`%D@e>@si7u@f}CYiFP8yf5R#j6Ep<C+*bL
zAKG;v<+YP|zQAVDC&vA`GJSSu*Bst+xAD&2`27>?#MHOhMZGSv+xC3@o`b24`vhyU
z>;lri*j~x_ZOgNB_C5=ZdA8qNH}B0XxoUUvf7U+kLL<A64>s>RE*xn0A>huw*S)`O
zTplXyPi1r5AI}}RKfzIae}(ouyL0ap?7zSIV|T#)h~2++pZBOX3D{3km9;<D%x`~m
z?tj}es!{tGPjlL9&*HK_JE3}SYs=g{Vek5FyEYZ?4HC}SXKZ}fE~v=WeoFR3yRV8m
z`(#%=w9^z{X`drpX~%b&eP6kNj-B?=`TM$<*zImOcJ8%vcd}Qg^|POGaIPJw<<WIP
zJ(B^{s^nl`z}NC%MQM2$x;Zg2L><{S(@x#jYTpA{4f|#H)b@Kx3ha-$y4OagDZy_0
ziCmjE$vpN!bM5TBL}m6_t!dl4q<G)nqbKtADGGMk^<^KkQwq6dYrj={zu$!*yOO{C
z`!bhbvAyB2(C&Jn@xIojnRXvizSzVqxv(!jf_MKf{TKUYHMj1|yvwy;OYDU0S$4gB
zQJtUm{`kYW-`q@h-;^V(?7}nU_isCKX5aFh>V4L=;``M*@9+C}yJ6qP15@^Wa6N1H
zK6d`TH|@>4|FhrOce#sk|C`28ySa7Vc5-@L_P2Mf-Ix79+fK#S!A>M*-aaR8BfEx#
zi*_3o0{42k7VSO7DZIbrnWCLB|FeD19*Wt<>h|wrt=zKj{@1m3m$X{;&2rshcfH(k
zUn>8CeMZtR?Tq?d_kA=K*lWC~&1Ty~NxONEd+ZAP73~(q?%Q``?{8anb6Go$qgr-u
zogH@17B$%I-XyzEaBAPaO&+dxv*Ww%LKWoqe>RM_Yg%Gr_m%mY-R*0tc9xpc_OAUq
zYhRbPxLw7)Py4<XP2Cq$HG6Mx(9V6V5yks*GoI`_mVdzZ{BA9~bvhgNB_-{!d$lOp
zZu#p~yZ7A>v@=?B%XZNWYrAOuV|EVWZ2R9zuH4(0yx*2N{KUSs1q}Oz&x-B$Uc7Z*
zN|pD%ox4@-&-HZKhE*N26+hN)6S3~uKI4}OcEt_Nc3#PB`&(StZ9(bB>Dvh=1||ju
z5GFqT7=hD|AiKSt+Iz*louyHBi|6s#wOB3Pqb0Dy_S8Aqeb2OAY{i8q>?_qVu>0uG
zYP)FF65GU3N!tY#QG5EH{oTWQVX|G*y?wT6=0?^k3tIL$_Dr&y|MT14-aU+Vf4yvM
zT|6(^{XFZp@3ipyy*F>L?z{H)pl#2-e%sGmQ}(3aZ`)`0Z0o*z|GwC1?(x|hXnn!9
zCFYIYf&yjRM{NOnSq;-{!^*$yJrw0?>)Z6qcIjVFyR4+gd(EruZ2CL0ZMImhwY@5T
zb?+ASh<#^5xc2{btFz0C548OiD86sqExEmoj~3W&wVq>lAXLGQA#>|qK0%|sEaxBE
z2K(){TN`L-dqBW?-v+sTduOhzx2=7kZ@2h`fSsC@tL@3=C3eTFJ9fuiUuf65)PJv`
z$jaS~*E#l;T0Y&&%BgO5nmKsiyZx(dI2;#S7gwybxp9xpj&YiRz0nS~Jq67^w)vYE
z*$63~-^2eydY@s7qfJJY!@kH5Q};O^i{9&K|H9U6-}ya%RU+(G_T8~Pu=a=Tkzfwn
z+zs38_J?xY>e}zJoi~li?ttdZy*D}o?4tf#?-fg~va```+t+yN?VefdllCQ^{lAA(
zvcPWDV!3@%J3DOIz6R};d2PNYHEY`5OPi9c<yY$19C=V?J6%N4?lALQyGI&M`?}Me
z_g#5^Xs^XHqdkv8`u9GL)!kP%-EPm0Qucj}pPBcyTZQc9TyMOOSKedK>pQ>qeqhnC
zxqrIY?%uXu>vpe&c0t`*`|K^PY^&}z?=>@<Vk^6u*Vf_cGMi^Fc=jE9+iyGN>zlp(
zJxzA6mUiruc4)G@7-hBR7{B~J*T#waS_&HMQal#e{0rV_6T58kUiRmw_N-qoWV`FM
zrrr0i6YYdHuh~~`dwjQ2RrEd`%f<U5<!0?IWpmi8;Fh^}--PwHJ_~2tXf)PZA1;a5
zmvi!`t<jb&8+Yz>+q-vv?M=L{XuI=xm>pZ#Cp)1jIs1Yao!rCxt;%+Gz)IWmZ*T4m
z-qN>EarfeVyRDAfdZp&r9&^aBjp%r?@5v3@Jv^n}cARdrY}=(T?yX*EzR%Nj!oKWd
zCi_l*zqGF`RmxWC$Hjg7T<+{S?hw3Z3NxeaFS!f0a|E6C?Nej7%WmSgnJGGFZ%KmC
zzV2HU`;MLG+JE)djok{>6ZXEc*4t|<XJ^-ZQPb}KJgI$wd$!qiPv2{|?k}I+mdl6s
zZhE@Tw*R5-zS+G-dnYDcvQb*$xGy`cVqex94ci4_3j4N*o!_(5P1Ww>>Dqni;tqTH
z-h}LHIkw5p@2!Jv#LqT6tvieN7(dv$_r(qmJHNOEd;L<QY>hYZ?lbDi+&ejV#$Jzr
z=X+(Mm)i)LzT0cV-?lHp^o8xI`Sa}}l~?U~czM>|dxt03#)QqVJN9Yv-un4p!R3cE
z=fo}sP&b~1l={aQTz(kG`|siVYqb|t4*fk@7YnV%@YG{`3^uTO3_7yl=43&37uuK$
zWb?&&>7md95Rc8W48|~<^}wC?;?$zd#GK5kM9|<<d}fN9lRl{qF)%eVGd4A}G&L|W
zF|;tWG%>U^H#9deFb1)WEQ}4!EG^6p%#A^=oFDDBM1z5W0fYm*89@$)kDMWocJPD7
zUr;dGI2yWU<e>`z6pd+23@|UlheyzLBR4;IQFLEJ=mv+E0lHS?ng>;D05{xXuvSBK
zt;mTERqF&^RINtnT9MNks@4mFs9KHDwIU}JJ`}%72*I^NlO?)d<mi({(K`vD7kk1A
z@MdGvfohRs)`hYdz>Pi-2b60-^gGaqBm)D3g)oQ*%8hKGF){`<Q3xL-9pKFh8lqz0
MVBlZ?sfVZq09epwKL7v#

literal 0
HcmV?d00001

diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py
new file mode 100644
index 0000000..50be37c
--- /dev/null
+++ b/reinforce_cartpole.py
@@ -0,0 +1,115 @@
+import gymnasium as gym
+import torch
+import numpy as np
+
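+# Simple MLP policy: maps the 4-dim CartPole observation to action probabilities over the 2 discrete actions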
+class Policy(torch.nn.Module):
+    def __init__(self, input_size=4, output_size=2):
+        super(Policy, self).__init__()
+        self.fc1 = torch.nn.Linear(input_size, 128)
+        self.relu = torch.nn.ReLU()
+        self.dropout = torch.nn.Dropout(0.2)
+        self.fc2 = torch.nn.Linear(128, output_size)
+        self.softmax = torch.nn.Softmax(dim=0)  # dim=0: the policy is always called on a single, unbatched observation
+    
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.softmax(x)
+        return x
+
+
+def main():
+    policy = Policy()
+    optimizer = torch.optim.Adam(policy.parameters(), lr=5e-3)
+
+    # Create the environment
+    env = gym.make("CartPole-v1")
+
+    # Reset the environment and get the initial observation
+
+    gamma = 0.99
+    total_reward = []
+    total_loss = []
+    epochs = 500
+
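+    # CartPole-v1 truncates episodes after max_episode_steps = 500 steps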
+    max_steps = env.spec.max_episode_steps
+
+    for episode in range(epochs):
+        print(f"episode {episode}")
+        # Reset the environment
+        observation = env.reset()[0]
+        # Reset the per-episode buffers
+        rewards = []
+        log_probs = []
+        for step in range(max_steps):
+            # Query the policy for action probabilities given the current observation
+            action_probs = policy(torch.from_numpy(observation).float())
+
+            # Sample an action from the action probabilities
+            action = torch.distributions.Categorical(action_probs).sample()
+            # Apply the action to the environment
+            observation, reward, terminated, truncated, info = env.step(action.numpy())
+            # Store the reward and the log-probability of the chosen action for this step
+            rewards.append(torch.tensor(reward))
+            log_probs.append(torch.log(action_probs[action]))
+
+            if terminated or truncated:
+                break
+
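+        # Compute the discounted return-to-go for every step t:
+        #   G_t = sum_j gamma**j * r_{t+j}
+        # (row i of rewards_tensor holds the rewards from step i onward)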
+        # transform rewards and log_probs into tensors
+        rewards = torch.stack(rewards)
+        log_probs = torch.stack(log_probs)
+        rewards_length = len(rewards)
+        rewards_tensor = torch.zeros(rewards_length, rewards_length)
+        for i in range(rewards_length):
+            for j in range(rewards_length - i):
+                rewards_tensor[i, j] = rewards[i + j]
+        # discount each future reward by gamma**j (np.pow only exists in NumPy >= 2.0, so use np.power)
+        for i in range(rewards_length):
+            for j in range(rewards_length):
+                rewards_tensor[i, j] = rewards_tensor[i, j] * np.power(gamma, j)
+        returns = torch.sum(rewards_tensor, dim=1)
+        # standardize the returns to reduce gradient variance (the epsilon guards against zero std)
+        normalized_rewards = returns - torch.mean(returns)
+        normalized_rewards /= torch.std(normalized_rewards) + 1e-8
+        
+
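+        # REINFORCE loss: negative log-likelihood of the chosen actions, weighted by the standardized returns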
+        loss = -torch.sum(log_probs * normalized_rewards)
+        total_reward.append(torch.sum(rewards).item())  # undiscounted episode return
+        # optimize
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        total_loss.append(loss.detach().numpy())
+        # Render the environment to visualize the agent's behavior
+        #env.render()
+
+    # save the model weights
+    torch.save(policy.state_dict(), "policy.pth")
+
+
+    print(total_reward)
+    print(total_loss)
+    env.close()
+
+    # plot the rewards and the loss side by side
+    import matplotlib.pyplot as plt
+    fig, ax = plt.subplots(1, 2)
+    ax[0].plot(total_reward)
+    ax[0].set_title("episode return")
+    ax[1].plot(total_loss)
+    ax[1].set_title("loss")
+    plt.show()
+
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
-- 
GitLab