Commit c0886cbe authored by cgerest

Update reinforce cartpole

parent 2e15fdbf
@@ -25,12 +25,6 @@ class PolicyNetwork(nn.Module):
        x = self.fc2(x)
        return self.softmax(x)

# Normalize function
def normalize_rewards(rewards):
    rewards = np.array(rewards)
    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)
    return rewards

if __name__ == "__main__":
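The normalize_rewards helper removed in the hunk above performed the same standardization that this commit now applies inline to the returns tensor in the second hunk below. A minimal sketch of the two forms, with illustrative sample values and variable names (note one small difference: np.std defaults to the population standard deviation, while torch.Tensor.std defaults to the unbiased sample estimator, so the results differ slightly):

import numpy as np
import torch

sample = [1.0, 0.0, 2.0, 1.0]

# Removed helper: standardize a reward list with NumPy
normalized_np = (np.array(sample) - np.mean(sample)) / (np.std(sample) + 1e-9)

# New inline form: the same standardization on a torch tensor of returns
returns = torch.tensor(sample)
normalized_torch = (returns - returns.mean()) / (returns.std() + 1e-9)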
@@ -63,25 +57,29 @@ if __name__ == "__main__":
        rewards = []
        while True:
            # Compute action probabilities
            state_tensor = torch.from_numpy(state).float().unsqueeze(0)
            action_probs = policy(state_tensor)
            # Sample action
            m = torch.distributions.Categorical(action_probs)
            action = m.sample()
            saved_log_probs.append(m.log_prob(action))
            state, reward, done, _, _ = env.step(action.item())
            # Step env with action
            rewards.append(reward)
            if done:
                break
        # Compute returns
        returns = torch.tensor([sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i)))
                                for i in range(len(rewards))])
        # Compute and normalize returns
        returns = torch.tensor(
            [sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i))) for i in range(len(rewards))]
        )
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        # Compute policy loss and entropy loss
        policy_loss = -torch.stack(saved_log_probs).mul(returns).sum()
        entropy_loss = -0.01 * (action_probs * torch.log(action_probs)).sum(dim=1).mean()
        total_loss = policy_loss + entropy_loss
        optimizer.zero_grad()
......
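The second hunk replaces the helper-based normalization with an inline computation of per-step discounted returns, G_i = sum_k gamma^k * r_{i+k} with gamma = 0.99, followed by the same standardization. A standalone sketch of that computation (only `rewards` and `returns` come from the commit; `gamma` and the sample values are illustrative):

import numpy as np
import torch

gamma = 0.99
rewards = [1.0, 1.0, 1.0]  # one reward per environment step in the episode

# Discount the tail of the reward list at every step, mirroring the expression
# in the diff (where 0.99 is hard-coded)
returns = torch.tensor(
    [sum(rewards[i:] * (gamma ** np.arange(len(rewards) - i))) for i in range(len(rewards))]
)

# Standardize before weighting the stored log-probabilities
returns = (returns - returns.mean()) / (returns.std() + 1e-9)

The standardized returns then weight the saved log-probabilities in the policy_loss line, giving the usual REINFORCE objective -sum_t log pi(a_t | s_t) * G_t.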