diff --git a/reinforce_cartpole.py b/reinforce_cartpole.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba52806cbe1ebbf3d1e21f0eb56fe5bfc71c55ff
--- /dev/null
+++ b/reinforce_cartpole.py
@@ -0,0 +1,72 @@
+import gym, matplotlib.pyplot as plt
+import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
+from torch.distributions import Categorical
+
+# Set up the CartPole environment (render_mode="human" needs gym >= 0.26 and renders at every step)
+env = gym.make("CartPole-v1", render_mode="human")
+
+# Set up the agent as a simple neural network mapping observations to action probabilities
+class Agent(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.FC1 = nn.Linear(env.observation_space.shape[0], 128)
+        self.FC2 = nn.Linear(128, env.action_space.n)
+
+    def forward(self, x):
+        x = self.FC1(x)
+        x = F.relu(x)
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = self.FC2(x)
+        return F.softmax(x, dim=1)
+
+# Create the agent and its optimizer once, so Adam keeps its state across episodes
+agent = Agent()
+optimizer = optim.Adam(agent.parameters(), lr=5e-3)
+rewards_tot = []
+
+# Train for 500 episodes
+for episode in range(500):
+    # Reset the environment (gym >= 0.26 returns an (obs, info) tuple)
+    obs = env.reset()
+    obs = obs[0] if isinstance(obs, tuple) else obs
+    # Reset the per-episode buffers
+    rewards, log_probs, terminated, truncated, step = [], [], False, False, 0
+
+    # Roll out one episode (CartPole-v1 truncates at 500 steps)
+    while not (terminated or truncated) and step < 500:
+        step += 1
+        # Compute action probabilities for the current observation
+        obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
+        probs = agent(obs_tensor)
+        # Sample an action and store its log-probability
+        dist = Categorical(probs)
+        action = dist.sample()
+        log_probs.append(dist.log_prob(action))
+        # Step the environment with the sampled action
+        obs, reward, terminated, truncated, _ = env.step(action.item())
+        rewards.append(reward)
+
+    # Compute the discounted returns (gamma = 0.99) and normalize them
+    R, returns = 0.0, []
+    for r in reversed(rewards):
+        R = r + 0.99 * R
+        returns.insert(0, R)
+    returns = torch.tensor(returns, dtype=torch.float32)
+    returns = (returns - returns.mean()) / (returns.std() + 1e-9)
+    rewards_tot.append(sum(rewards))
+    # REINFORCE policy-gradient loss
+    loss = -torch.sum(torch.cat(log_probs) * returns)
+    # Update the policy
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+# Close the environment
+env.close()
+# Plot the total reward per episode
+plt.figure()
+plt.plot(rewards_tot)
+plt.xlabel('Episode')
+plt.ylabel('Reward')
+plt.title('REINFORCE on CartPole-v1')
+plt.show()