reinforce_cartpole.py (forked from Dellandrea Emmanuel / MSO_3_4-TD2)
    import gymnasium as gym
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    # Policy network: maps the 4-dimensional CartPole observation to 2 action probabilities
    class Policy(torch.nn.Module):
        def __init__(self, input_size=4, output_size=2):
            super(Policy, self).__init__()
            self.fc1 = torch.nn.Linear(input_size, 128)
            self.relu = torch.nn.ReLU()
            self.dropout = torch.nn.Dropout(0.2)
            self.fc2 = torch.nn.Linear(128, output_size)
            self.softmax = torch.nn.Softmax(dim=-1)  # act on the action dimension (also safe for batched input)
        
        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.fc2(x)
            return self.softmax(x)
    
    
    def main():
        policy = Policy()
        optimizer = torch.optim.Adam(policy.parameters(), lr=5e-3)
    
        # Create the environment
        env = gym.make("CartPole-v1")
    
        # Training hyperparameters and per-episode logs
        gamma = 0.99
        total_reward = []
        total_loss = []
        num_episodes = 500
    
        max_steps = env.spec.max_episode_steps
    
        for episode in range(num_episodes):
            print(episode)
            # Reset the environment and get the initial observation
            observation, info = env.reset()
            # Per-episode buffers for rewards and action log-probabilities
            rewards = []
            log_probs = []
            for step in range(max_steps):
                # Compute the action probabilities from the current policy
                action_probs = policy(torch.from_numpy(observation).float())
    
                # Sample an action from the action probabilities
                action = torch.distributions.Categorical(action_probs).sample()
                # Apply the action to the environment
                observation, reward, terminated, truncated, info = env.step(action.item())
                # Store the reward and the log-probability of the chosen action
                rewards.append(torch.tensor(reward))
                log_probs.append(torch.log(action_probs[action]))

                if terminated or truncated:
                    break
    
            # Compute the discounted return G_t = sum_k gamma^k * r_{t+k} for each step,
            # accumulating backwards through the episode
            rewards = torch.stack(rewards)
            log_probs = torch.stack(log_probs)
            returns = torch.zeros(len(rewards))
            running_return = 0.0
            for t in reversed(range(len(rewards))):
                running_return = rewards[t] + gamma * running_return
                returns[t] = running_return
            # Normalize the returns to reduce gradient variance
            # (the small epsilon guards against a zero standard deviation)
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

            # REINFORCE loss: log-probabilities weighted by the normalized returns
            loss = -torch.sum(log_probs * returns)
            total_reward.append(rewards.sum().item())

            # Gradient step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss.append(loss.item())
    
        # save the model weights
        torch.save(policy.state_dict(), "policy.pth")
    
    
        print(total_reward)
        print(total_loss)
        env.close()
    
        # Plot the per-episode rewards and the loss side by side
        fig, ax = plt.subplots(1, 2)
        ax[0].plot(total_reward)
        ax[0].set_title("Episode reward")
        ax[1].plot(total_loss)
        ax[1].set_title("Loss")
        plt.show()
    
    
    
    if __name__ == "__main__":
        main()
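
A minimal evaluation sketch for the saved weights, assuming the script above is importable as reinforce_cartpole and that policy.pth has been produced by main(); the greedy argmax action selection and the render_mode="human" window are illustrative choices rather than part of the original script.

    # evaluate_cartpole.py (hypothetical helper): reload the trained policy and watch one episode
    import gymnasium as gym
    import torch
    from reinforce_cartpole import Policy  # assumes the training script is importable as a module

    policy = Policy()
    policy.load_state_dict(torch.load("policy.pth"))
    policy.eval()  # disable dropout for evaluation

    env = gym.make("CartPole-v1", render_mode="human")
    observation, info = env.reset()
    episode_reward = 0.0
    done = False
    while not done:
        with torch.no_grad():
            action_probs = policy(torch.from_numpy(observation).float())
        action = torch.argmax(action_probs).item()  # greedy action (illustrative choice)
        observation, reward, terminated, truncated, info = env.step(action)
        episode_reward += reward
        done = terminated or truncated
    print("Episode reward:", episode_reward)
    env.close()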