diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py
index 3fee45cfe9cb38244f8bed8a768fc073ab4d1872..01dc63a1ae0eee15ceed406efeba49c14179a3dd 100644
--- a/reinforce_cartpole.py
+++ b/reinforce_cartpole.py
@@ -25,12 +25,6 @@ class PolicyNetwork(nn.Module):
         x = self.fc2(x)
         return self.softmax(x)
 
-# Normalize function
-def normalize_rewards(rewards):
-    rewards = np.array(rewards)
-    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)
-    return rewards
-
 
 
 if __name__ == "__main__":
@@ -63,25 +57,29 @@ if __name__ == "__main__":
         rewards = []
 
         while True:
+            # Compute action probabilities
             state_tensor = torch.from_numpy(state).float().unsqueeze(0)
             action_probs = policy(state_tensor)
+
+            # Sample action
             m = torch.distributions.Categorical(action_probs)
             action = m.sample()
             saved_log_probs.append(m.log_prob(action))
+            # Step env with action
             state, reward, done, _, _ = env.step(action.item())
             rewards.append(reward)
             if done:
                 break
 
-        # Compute returns
-        returns = torch.tensor([sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i)))
-                                for i in range(len(rewards))])
+        # Compute and normalize returns
+        returns = torch.tensor(
+            [sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i))) for i in range(len(rewards))]
+        )
         returns = (returns - returns.mean()) / (returns.std() + 1e-9)
 
         # Compute policy loss and entropy loss
         policy_loss = -torch.stack(saved_log_probs).mul(returns).sum()
         entropy_loss = -0.01 * (action_probs * torch.log(action_probs)).sum(dim=1).mean()
-        total_loss = policy_loss + entropy_loss
 
         optimizer.zero_grad()