diff --git a/a2c_sb3_cartpole.py b/a2c_sb3_cartpole.py index 2e817e9b35ddbe1ab15f58fd6c506595c78361a5..399e46f27da68e21c9018a6da434d8fed07b4cb8 100644 --- a/a2c_sb3_cartpole.py +++ b/a2c_sb3_cartpole.py @@ -1,10 +1,11 @@ +import numpy as np +from tqdm import tqdm + import gym from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.evaluation import evaluate_policy import matplotlib.pyplot as plt -import numpy as np -from tqdm import tqdm + if __name__ == "__main__": episodes = 500 @@ -34,9 +35,6 @@ if __name__ == "__main__": episode_rewards.append(episode_reward) - # Log progress - # print(f"Episode: {episode + 1}, Reward: {episode_reward_sum}") - # Save model model.save("a2c_cartpole_model") diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py index 01dc63a1ae0eee15ceed406efeba49c14179a3dd..30b5f006292e5118d924d9b1c1fb5f7bc8cc845f 100644 --- a/reinforce_cartpole.py +++ b/reinforce_cartpole.py @@ -31,7 +31,7 @@ if __name__ == "__main__": # Hyperparameters learning_rate = 5e-3 gamma = 0.99 - episodes = 450 + episodes = 500 # Environment setup env = gym.make("CartPole-v1") # , render_mode="human")