diff --git a/reinforce_cartpole.png b/reinforce_cartpole.png deleted file mode 100644 index 8a14610e2263b5a2a77866d3cf0821195bbb5efa..0000000000000000000000000000000000000000 Binary files a/reinforce_cartpole.png and /dev/null differ diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py index d1494dbc3baa418b9c7b2af0329b39e304f16de4..598719db1a9d663dbe995aa09d3f5b5ccc1d5c63 100644 --- a/reinforce_cartpole.py +++ b/reinforce_cartpole.py @@ -28,9 +28,11 @@ optimizer = optim.Adam(model.parameters(), lr=5e-3) # Keep track of the number of rewards for each episodes rewardsByEpisode = [] +lossesByEpisode = [] # training loop for episode in range(500): + print("Episode's number {0} / {1}".format(episode, 500)) # reset the environment state = env.reset() log_probs = [] @@ -68,19 +70,32 @@ for episode in range(500): for log_prob, return_ in zip(log_probs, returns): model_loss.append(-log_prob * return_) model_loss = torch.cat(model_loss).sum() + + lossesByEpisode.append(model_loss.item()) + # update the model optimizer.zero_grad() model_loss.backward() optimizer.step() - print("N° de l'épisode :", episode) - print("Nombre de rewards :", len(rewards)) + # X axis : x = list(range(len(rewardsByEpisode))) +# Plot Rewards by Episodes +plt.figure("Figure 1") plt.xlabel('Episodes N°') plt.ylabel('Number of rewards given') -plt.plot(x, rewardsByEpisode, '--') +plt.plot(x, rewardsByEpisode) +plt.savefig('reinforce_cartpole_reward.png') +plt.show() + +# Plot losses by Episodes +plt.figure("Figure 2") +plt.xlabel("Episodes N°") +plt.ylabel("Loss") +plt.plot(x, lossesByEpisode) +plt.title("Loss by episodes") +plt.savefig("reinforce_cartpole_loss.png") plt.show() -plt.savefig('reinforce_cartpole.png') diff --git a/reinforce_cartpole_loss.png b/reinforce_cartpole_loss.png new file mode 100644 index 0000000000000000000000000000000000000000..9ace00b9a30875dd9347f2ac29f82798f793caa9 Binary files /dev/null and b/reinforce_cartpole_loss.png differ diff --git a/reinforce_cartpole_reward.png b/reinforce_cartpole_reward.png new file mode 100644 index 0000000000000000000000000000000000000000..bedbc299dd67ce91b14259b81a495707f81197ae Binary files /dev/null and b/reinforce_cartpole_reward.png differ