Commit 2c4c5de5 authored by MaximeCerise

Reinforce_cartpole ok
import matplotlib.pyplot as plt
import pandas as pd

# Load the per-episode scores written by the training script
df = pd.read_csv("saves/scores.csv")

plt.figure()
plt.plot(df['iter'], df['score'])
plt.xlabel("Episode")
plt.ylabel("Total reward")
plt.savefig("saves/plot_rewards.png")
import gymnasium as gym

# Create the environment
env = gym.make("CartPole-v1", render_mode="human")

# Reset the environment and get the initial observation
observation, info = env.reset()

for _ in range(100):
    # Select a random action from the action space
    action = env.action_space.sample()

    # Apply the action to the environment.
    # Returns the next observation, the reward, the terminated and truncated
    # flags (indicating whether the episode has ended), and an additional
    # info dictionary
    observation, reward, terminated, truncated, info = env.step(action)

    # Render the environment to visualize the agent's behavior
    env.render()

    if terminated or truncated:
        # Episode ended before the step limit
        break

env.close()
import os

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd


class PolicyNetwork(nn.Module):
    """Small MLP mapping a CartPole observation to action probabilities."""

    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.15)
        self.fc2 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
def reinforce_cartpole_main():
    env = gym.make("CartPole-v1", render_mode="human")
    policy = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)
    optimizer = optim.Adam(policy.parameters(), lr=4e-3)
    gamma = 0.99
    num_episodes = 500
    save_scores = []

    # Make sure the output directory exists before saving anything
    os.makedirs("saves", exist_ok=True)

    for episode in range(num_episodes):
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32)
        log_probs = []
        rewards = []
        done = False

        # Roll out one full episode with the current policy
        while not done:
            action_probs = policy(state)
            distribution = torch.distributions.Categorical(action_probs)
            action = distribution.sample()
            log_probs.append(distribution.log_prob(action))

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            state = torch.tensor(next_state, dtype=torch.float32)
            done = terminated or truncated

        # Compute the discounted return G_t for every step of the episode
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

        # Normalize the returns to reduce the variance of the gradient estimate
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        # REINFORCE loss: minus the sum of log pi(a_t | s_t) weighted by G_t
        policy_loss = -torch.sum(torch.stack(log_probs) * returns)
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        save_scores.append(np.sum(rewards))
        print(f"Episode {episode}, total reward: {np.sum(rewards)}")

    torch.save(policy.state_dict(), "saves/policy_cartpole.pth")
    print("Scores & model saved!")

    # Save the scores under the column names expected by the plotting script
    df_scores = pd.DataFrame({"iter": range(len(save_scores)), "score": save_scores})
    df_scores.to_csv("saves/scores.csv", index=False)

    env.close()


if __name__ == "__main__":
    reinforce_cartpole_main()
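Not part of the commit, but a minimal evaluation sketch showing how the saved checkpoint could be used afterwards. It assumes the file saves/policy_cartpole.pth produced by the training script above, and that PolicyNetwork is available (e.g. imported from that script, whose filename is not shown here).

import gymnasium as gym
import torch

# Illustrative evaluation sketch (an assumption, not the author's code):
# load the trained policy and act greedily for one rendered episode.
env = gym.make("CartPole-v1", render_mode="human")
policy = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)  # PolicyNetwork from the training script
policy.load_state_dict(torch.load("saves/policy_cartpole.pth"))
policy.eval()  # disable dropout at evaluation time

state, _ = env.reset()
total_reward = 0.0
done = False
while not done:
    with torch.no_grad():
        action_probs = policy(torch.tensor(state, dtype=torch.float32))
    action = torch.argmax(action_probs).item()  # greedy action instead of sampling
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated

print(f"Evaluation episode reward: {total_reward}")
env.close()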
gymnasium
gymnasium[classic-control]
torch
pandas
numpy==1.26
matplotlib
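Assuming this list is stored as a requirements.txt file, the dependencies can be installed with pip install -r requirements.txt before running the scripts above.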