diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py
index 3fee45cfe9cb38244f8bed8a768fc073ab4d1872..01dc63a1ae0eee15ceed406efeba49c14179a3dd 100644
--- a/reinforce_cartpole.py
+++ b/reinforce_cartpole.py
@@ -25,12 +25,6 @@ class PolicyNetwork(nn.Module):
         x = self.fc2(x)
         return self.softmax(x)
 
-# Normalize function
-def normalize_rewards(rewards):
-    rewards = np.array(rewards)
-    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)
-    return rewards
-
 
 if __name__ == "__main__":
     
@@ -63,25 +57,34 @@ if __name__ == "__main__":
         rewards = []
 
         while True:
+            # Compute action probabilities
             state_tensor = torch.from_numpy(state).float().unsqueeze(0)
             action_probs = policy(state_tensor)
+
+            # Sample action
             m = torch.distributions.Categorical(action_probs)
             action = m.sample()
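+            # Save the log-probability of the sampled action for the policy-gradient update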
             saved_log_probs.append(m.log_prob(action))
-            state, reward, done, _, _ = env.step(action.item())
+            # Step the environment with the sampled action; treat truncation as episode end
+            state, reward, terminated, truncated, _ = env.step(action.item())
+            done = terminated or truncated
             rewards.append(reward)
             if done:
                 break
 
-        # Compute returns
-        returns = torch.tensor([sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i)))
-                                 for i in range(len(rewards))])
+        # Compute and normalize returns
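+        # Discounted return G_t = sum_k (0.99 ** k) * r_{t+k}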
+        returns = torch.tensor(
+            [sum(rewards[i:] * (0.99 ** np.arange(len(rewards) - i))) for i in range(len(rewards))]
+        )
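+        # Standardizing returns reduces the variance of the gradient estimate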
         returns = (returns - returns.mean()) / (returns.std() + 1e-9)
 
         # Compute policy loss and entropy loss
         policy_loss = -torch.stack(saved_log_probs).mul(returns).sum()
-        entropy_loss = -0.01 * (action_probs * torch.log(action_probs)).sum(dim=1).mean()
-
+        # Entropy bonus from the last step's action_probs: -0.01 * H(pi), so minimizing the total loss encourages exploration
+        entropy_loss = 0.01 * (action_probs * torch.log(action_probs)).sum(dim=1).mean()
         total_loss = policy_loss + entropy_loss
 
         optimizer.zero_grad()