From 36eabce609ea2475cbae87f93a2df8bb9932013d Mon Sep 17 00:00:00 2001
From: number_cruncher <lennart.oestreich@stud.tu-darmstadt.de>
Date: Fri, 14 Mar 2025 15:15:07 +0100
Subject: [PATCH] init

---
 a2c_sb3_cartpole.py            |  47 ++++++++++++
 evaluate_reinforce_cartpole.py |  32 ++++++++
 policy.pth                     | Bin 0 -> 5512 bytes
 reinforce_cartpole.py          | 121 ++++++++++++++++++++++++++++++
 4 files changed, 200 insertions(+)
 create mode 100644 a2c_sb3_cartpole.py
 create mode 100644 evaluate_reinforce_cartpole.py
 create mode 100644 policy.pth
 create mode 100644 reinforce_cartpole.py

diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py
new file mode 100644
index 0000000..ee1087d
--- /dev/null
+++ b/a2c_sb3_cartpole.py
@@ -0,0 +1,47 @@
+import gymnasium as gym
+
+from stable_baselines3 import A2C
+import wandb
+from wandb.integration.sb3 import WandbCallback
+from huggingface_sb3 import package_to_hub
+
+
+# Run configuration, following the wandb SB3 integration documentation
+config = {
+    "policy_type": "MlpPolicy",
+    "total_timesteps": 25000,
+    "env_name": "CartPole-v1",
+}
+run = wandb.init(
+    project="sb3",
+    config=config,
+    sync_tensorboard=True,
+    monitor_gym=True,
+    save_code=True,
+)
+
+env = gym.make(config["env_name"], render_mode="rgb_array")
+
+model = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
+model.learn(
+    total_timesteps=config["total_timesteps"],
+    callback=WandbCallback(
+        gradient_save_freq=100,
+        model_save_path=f"models/{run.id}",
+        verbose=2,
+    ),
+)
+
+# Roll out the trained policy for a quick visual check
+vec_env = model.get_env()
+obs = vec_env.reset()
+for i in range(1000):
+    action, _state = model.predict(obs, deterministic=True)
+    obs, reward, done, info = vec_env.step(action)
+    vec_env.render("human")
+    # The VecEnv resets automatically when an episode ends, so no manual reset is needed
+
+run.finish()
+
+package_to_hub(model=model,
+               model_name="CartPole-v1",
+               model_architecture="A2C",
+               env_id="CartPole-v1",
+               eval_env=env,
+               repo_id="lennartoe/Cartpole-v1",
+               commit_message="First commit")
diff --git a/evaluate_reinforce_cartpole.py b/evaluate_reinforce_cartpole.py
new file mode 100644
index 0000000..01479f7
--- /dev/null
+++ b/evaluate_reinforce_cartpole.py
@@ -0,0 +1,32 @@
+import gymnasium as gym
+import torch
+
+from reinforce_cartpole import Policy
+
+# Create the environment
+env = gym.make("CartPole-v1", render_mode="human")
+
+# Reset the environment and get the initial observation
+observation = env.reset()[0]
+
+# Load the learned policy and switch it to evaluation mode
+policy = Policy()
+policy.load_state_dict(torch.load('policy.pth', weights_only=True))
+policy.eval()
+
+for _ in range(200):
+    # Sample an action from the policy
+    action_probs = policy(torch.from_numpy(observation).float())
+    action = torch.distributions.Categorical(action_probs).sample()
+    # Apply the action to the environment; step() returns the next observation,
+    # the reward, the terminated/truncated flags and an info dictionary
+    observation, reward, terminated, truncated, info = env.step(action.item())
+    # Render the environment to visualize the agent's behavior
+    env.render()
+    if terminated or truncated:
+        # The episode ended before the step limit
+        break
+
+env.close()
diff --git a/policy.pth b/policy.pth
new file mode 100644
index 0000000000000000000000000000000000000000..970f5c065a6275f6d9a5038936d3e7bff9adc6ac
GIT binary patch
literal 5512
[5512-byte binary payload omitted]

z`!bhbvAyB2(C&Jn@xIojnRXvizSzVqxv(!jf_MKf{TKUYHMj1|yvwy;OYDU0S$4gB zQJtUm{`kYW-`q@h-;^V(?7}nU_isCKX5aFh>V4L=;``M*@9+C}yJ6qP15@^Wa6N1H zK6d`TH|@>4|FhrOce#sk|C`28ySa7Vc5-@L_P2Mf-Ix79+fK#S!A>M*-aaR8BfEx# zi*_3o0{42k7VSO7DZIbrnWCLB|FeD19*Wt<>h|wrt=zKj{@1m3m$X{;&2rshcfH(k zUn>8CeMZtR?Tq?d_kA=K*lWC~&1Ty~NxONEd+ZAP73~(q?%Q``?{8anb6Go$qgr-u zogH@17B$%I-XyzEaBAPaO&+dxv*Ww%LKWoqe>RM_Yg%Gr_m%mY-R*0tc9xpc_OAUq zYhRbPxLw7)Py4<XP2Cq$HG6Mx(9V6V5yks*GoI`_mVdzZ{BA9~bvhgNB_-{!d$lOp zZu#p~yZ7A>v@=?B%XZNWYrAOuV|EVWZ2R9zuH4(0yx*2N{KUSs1q}Oz&x-B$Uc7Z* zN|pD%ox4@-&-HZKhE*N26+hN)6S3~uKI4}OcEt_Nc3#PB`&(StZ9(bB>Dvh=1||ju z5GFqT7=hD|AiKSt+Iz*louyHBi|6s#wOB3Pqb0Dy_S8Aqeb2OAY{i8q>?_qVu>0uG zYP)FF65GU3N!tY#QG5EH{oTWQVX|G*y?wT6=0?^k3tIL$_Dr&y|MT14-aU+Vf4yvM zT|6(^{XFZp@3ipyy*F>L?z{H)pl#2-e%sGmQ}(3aZ`)`0Z0o*z|GwC1?(x|hXnn!9 zCFYIYf&yjRM{NOnSq;-{!^*$yJrw0?>)Z6qcIjVFyR4+gd(EruZ2CL0ZMImhwY@5T zb?+ASh<#^5xc2{btFz0C548OiD86sqExEmoj~3W&wVq>lAXLGQA#>|qK0%|sEaxBE z2K(){TN`L-dqBW?-v+sTduOhzx2=7kZ@2h`fSsC@tL@3=C3eTFJ9fuiUuf65)PJv` z$jaS~*E#l;T0Y&&%BgO5nmKsiyZx(dI2;#S7gwybxp9xpj&YiRz0nS~Jq67^w)vYE z*$63~-^2eydY@s7qfJJY!@kH5Q};O^i{9&K|H9U6-}ya%RU+(G_T8~Pu=a=Tkzfwn z+zs38_J?xY>e}zJoi~li?ttdZy*D}o?4tf#?-fg~va```+t+yN?VefdllCQ^{lAA( zvcPWDV!3@%J3DOIz6R};d2PNYHEY`5OPi9c<yY$19C=V?J6%N4?lALQyGI&M`?}Me z_g#5^Xs^XHqdkv8`u9GL)!kP%-EPm0Qucj}pPBcyTZQc9TyMOOSKedK>pQ>qeqhnC zxqrIY?%uXu>vpe&c0t`*`|K^PY^&}z?=>@<Vk^6u*Vf_cGMi^Fc=jE9+iyGN>zlp( zJxzA6mUiruc4)G@7-hBR7{B~J*T#waS_&HMQal#e{0rV_6T58kUiRmw_N-qoWV`FM zrrr0i6YYdHuh~~`dwjQ2RrEd`%f<U5<!0?IWpmi8;Fh^}--PwHJ_~2tXf)PZA1;a5 zmvi!`t<jb&8+Yz>+q-vv?M=L{XuI=xm>pZ#Cp)1jIs1Yao!rCxt;%+Gz)IWmZ*T4m z-qN>EarfeVyRDAfdZp&r9&^aBjp%r?@5v3@Jv^n}cARdrY}=(T?yX*EzR%Nj!oKWd zCi_l*zqGF`RmxWC$Hjg7T<+{S?hw3Z3NxeaFS!f0a|E6C?Nej7%WmSgnJGGFZ%KmC zzV2HU`;MLG+JE)djok{>6ZXEc*4t|<XJ^-ZQPb}KJgI$wd$!qiPv2{|?k}I+mdl6s zZhE@Tw*R5-zS+G-dnYDcvQb*$xGy`cVqex94ci4_3j4N*o!_(5P1Ww>>Dqni;tqTH z-h}LHIkw5p@2!Jv#LqT6tvieN7(dv$_r(qmJHNOEd;L<QY>hYZ?lbDi+&ejV#$Jzr z=X+(Mm)i)LzT0cV-?lHp^o8xI`Sa}}l~?U~czM>|dxt03#)QqVJN9Yv-un4p!R3cE z=fo}sP&b~1l={aQTz(kG`|siVYqb|t4*fk@7YnV%@YG{`3^uTO3_7yl=43&37uuK$ zWb?&&>7md95Rc8W48|~<^}wC?;?$zd#GK5kM9|<<d}fN9lRl{qF)%eVGd4A}G&L|W zF|;tWG%>U^H#9deFb1)WEQ}4!EG^6p%#A^=oFDDBM1z5W0fYm*89@$)kDMWocJPD7 zUr;dGI2yWU<e>`z6pd+23@|UlheyzLBR4;IQFLEJ=mv+E0lHS?ng>;D05{xXuvSBK zt;mTERqF&^RINtnT9MNks@4mFs9KHDwIU}JJ`}%72*I^NlO?)d<mi({(K`vD7kk1A z@MdGvfohRs)`hYdz>Pi-2b60-^gGaqBm)D3g)oQ*%8hKGF){`<Q3xL-9pKFh8lqz0 MVBlZ?sfVZq09epwKL7v# literal 0 HcmV?d00001 diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py new file mode 100644 index 0000000..50be37c --- /dev/null +++ b/reinforce_cartpole.py @@ -0,0 +1,121 @@ +import gymnasium as gym +import torch +import numpy as np + +class Policy(torch.nn.Module): + def __init__(self, input_size=4, output_size=2): + super(Policy, self).__init__() + self.fc1 = torch.nn.Linear(input_size, 128) + self.relu = torch.nn.ReLU() + self.dropout = torch.nn.Dropout(0.2) + self.fc2 = torch.nn.Linear(128, output_size) + self.softmax = torch.nn.Softmax(dim=0) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.dropout(x) + x = self.fc2(x) + #print(x) + x = self.softmax(x) + #print(x) + return x + + +def main(): + policy = Policy() + optimizer = torch.optim.Adam(policy.parameters(), lr=5e-3) + + # Create the environment + env = gym.make("CartPole-v1") + + # Reset the environment and get the initial observation + + gamma = 0.99 + total_reward = [] + total_loss = [] + epochs = 500 + + max_steps = env.spec.max_episode_steps + + for _ in range(epochs): + print(_) + # Reset the 
+        # Reset the environment and the per-episode buffers
+        observation = env.reset()[0]
+        rewards = []
+        log_probs = []
+        for step in range(max_steps):
+            # Query the policy for the action probabilities of the current state
+            action_probs = policy(torch.from_numpy(observation).float())
+            # Sample an action from the action probabilities
+            action = torch.distributions.Categorical(action_probs).sample()
+            # Apply the action to the environment
+            observation, reward, terminated, truncated, info = env.step(action.item())
+            # Store the reward together with the log-probability of the action
+            # that produced it
+            rewards.append(torch.tensor(reward))
+            log_probs.append(torch.log(action_probs[action]))
+
+            if terminated or truncated:
+                break
+
+        # Turn the buffers into tensors and compute the discounted return of
+        # each step: G_t = sum_k gamma**k * r_{t+k}
+        rewards = torch.stack(rewards)
+        log_probs = torch.stack(log_probs)
+        rewards_length = len(rewards)
+        rewards_tensor = torch.zeros(rewards_length, rewards_length)
+        for i in range(rewards_length):
+            for j in range(rewards_length - i):
+                rewards_tensor[i, j] = rewards[i + j]
+        for i in range(rewards_length):
+            for j in range(rewards_length):
+                rewards_tensor[i, j] = rewards_tensor[i, j] * np.power(gamma, j)
+        returns = torch.sum(rewards_tensor, dim=1)
+        # Normalize the returns to zero mean and unit variance
+        normalized_returns = returns - torch.mean(returns)
+        normalized_returns /= torch.std(normalized_returns)
+
+        # REINFORCE loss: log-probabilities weighted by the normalized returns
+        loss = -torch.sum(log_probs * normalized_returns)
+        total_reward.append(rewards.sum().item())
+        # Optimize the policy
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        total_loss.append(loss.detach().numpy())
+
+    # Save the model weights
+    torch.save(policy.state_dict(), "policy.pth")
+
+    print(total_reward)
+    print(total_loss)
+    env.close()
+
+    # Plot the episode rewards and the loss side by side
+    fig, ax = plt.subplots(1, 2)
+    ax[0].plot(total_reward)
+    ax[1].plot(total_loss)
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()
--
GitLab
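For reference, the per-step discounted returns that reinforce_cartpole.py assembles with the two nested loops (G_t = sum_k gamma^k * r_{t+k}) can also be computed in a single backward pass over the episode. The following is a minimal sketch, not part of the patch, assuming `rewards` is a 1-D tensor of per-step rewards for one episode; the helper name `discounted_returns` is illustrative only.

import torch

def discounted_returns(rewards, gamma=0.99):
    # rewards: 1-D tensor of per-step rewards for a single episode
    # returns[t] = rewards[t] + gamma * returns[t + 1], accumulated from the last step
    returns = torch.zeros_like(rewards)
    running = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

This yields the same values as the rewards_tensor construction above, in O(n) rather than O(n^2) time, and could stand in for the two nested loops before the normalization step.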