Commit c571cd65 authored by Benyahia Mohammed Oussama's avatar Benyahia Mohammed Oussama

Replace evaluate_reinforce_cartpole.ipynb

parent 0cde9152
%% Cell type:code id: tags:
``` python
import os

import gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from gym.wrappers import RecordVideo


# Define the policy network architecture (must match the trained model)
class Policy_Network(nn.Module):
    def __init__(self, num_observations, num_actions):
        super(Policy_Network, self).__init__()
        self.fc1 = nn.Linear(num_observations, 128)
        self.fc2 = nn.Linear(128, num_actions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return self.softmax(x)


# Load the trained weights and switch the model to evaluation mode
def load_model(model_path, num_observations, num_actions):
    model = Policy_Network(num_observations, num_actions)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    return model


# Record a single episode and save it as a video in the specified folder
def record_episode(model, video_folder=r"C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL", max_time_steps=500):
    os.makedirs(video_folder, exist_ok=True)  # Ensure the folder exists
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, video_folder=video_folder, episode_trigger=lambda x: True)
    observation, info = env.reset()
    done = False
    time_steps = 0
    while not done and time_steps < max_time_steps:
        obs_tensor = torch.tensor(np.array(observation), dtype=torch.float32).unsqueeze(0)
        action_probs = model(obs_tensor)
        action = np.argmax(action_probs.detach().numpy())  # Choose the action with the highest probability
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        time_steps += 1
    env.close()
    print(f"Video saved in: {video_folder}")


# Evaluate the model over multiple episodes and plot the rewards
def evaluate_model(model, num_episodes=100, render=False):
    max_time_steps = 500
    env = gym.make("CartPole-v1")
    rewards = []
    for episode in range(num_episodes):
        observation, info = env.reset()
        done = False
        total_reward = 0
        time_steps = 0
        while not done and time_steps < max_time_steps:
            obs_tensor = torch.tensor(np.array(observation), dtype=torch.float32).unsqueeze(0)
            action_probs = model(obs_tensor)
            action = np.argmax(action_probs.detach().numpy())  # Choose the action with the highest probability
            observation, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            time_steps += 1
            done = terminated or truncated
            if render:
                env.render()
        rewards.append(total_reward)  # Store the total reward for the episode
        print(f"Episode {episode+1}: Reward = {total_reward}")
    env.close()

    # Plot the rewards across episodes
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_episodes + 1), rewards, marker="o", linestyle="-", color="b", label="Episode Reward")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Rewards Per Episode")
    plt.legend()
    plt.grid()
    plt.show()


if __name__ == "__main__":
    # Create the environment only to query its observation and action dimensions
    env = gym.make("CartPole-v1")
    num_observations = env.observation_space.shape[0]
    num_actions = env.action_space.n

    # Load the trained model
    model = load_model("reinforce_cartpole.pth", num_observations, num_actions)

    # Record a single episode and save it as a video
    record_episode(model)

    # Evaluate the model without rendering
    evaluate_model(model, render=False)
```
%% Output
c:\PYTHON\lib\site-packages\gym\wrappers\record_video.py:75: UserWarning: WARN: Overwriting existing videos at C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)
logger.warn(
MoviePy - Building video C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4.
MoviePy - Writing video C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4
MoviePy - Done !
MoviePy - video ready C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4
Video saved in: C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL
Episode 1: Reward = 500.0
Episode 2: Reward = 500.0
Episode 3: Reward = 500.0
Episode 4: Reward = 500.0
Episode 5: Reward = 500.0
Episode 6: Reward = 500.0
Episode 7: Reward = 500.0
Episode 8: Reward = 500.0
Episode 9: Reward = 500.0
Episode 10: Reward = 500.0
Episode 11: Reward = 500.0
Episode 12: Reward = 500.0
Episode 13: Reward = 500.0
Episode 14: Reward = 500.0
Episode 15: Reward = 500.0
Episode 16: Reward = 500.0
Episode 17: Reward = 500.0
Episode 18: Reward = 500.0
Episode 19: Reward = 500.0
Episode 20: Reward = 500.0
Episode 21: Reward = 500.0
Episode 22: Reward = 500.0
Episode 23: Reward = 500.0
Episode 24: Reward = 500.0
Episode 25: Reward = 500.0
Episode 26: Reward = 500.0
Episode 27: Reward = 500.0
Episode 28: Reward = 500.0
Episode 29: Reward = 500.0
Episode 30: Reward = 500.0
Episode 31: Reward = 500.0
Episode 32: Reward = 500.0
Episode 33: Reward = 500.0
Episode 34: Reward = 500.0
Episode 35: Reward = 500.0
Episode 36: Reward = 500.0
Episode 37: Reward = 500.0
Episode 38: Reward = 500.0
Episode 39: Reward = 500.0
Episode 40: Reward = 500.0
Episode 41: Reward = 500.0
Episode 42: Reward = 500.0
Episode 43: Reward = 500.0
Episode 44: Reward = 500.0
Episode 45: Reward = 500.0
Episode 46: Reward = 500.0
Episode 47: Reward = 500.0
Episode 48: Reward = 500.0
Episode 49: Reward = 500.0
Episode 50: Reward = 500.0
Episode 51: Reward = 500.0
Episode 52: Reward = 500.0
Episode 53: Reward = 500.0
Episode 54: Reward = 500.0
Episode 55: Reward = 500.0
Episode 56: Reward = 500.0
Episode 57: Reward = 500.0
Episode 58: Reward = 500.0
Episode 59: Reward = 500.0
Episode 60: Reward = 500.0
Episode 61: Reward = 500.0
Episode 62: Reward = 500.0
Episode 63: Reward = 500.0
Episode 64: Reward = 500.0
Episode 65: Reward = 500.0
Episode 66: Reward = 500.0
Episode 67: Reward = 500.0
Episode 68: Reward = 500.0
Episode 69: Reward = 500.0
Episode 70: Reward = 500.0
Episode 71: Reward = 500.0
Episode 72: Reward = 500.0
Episode 73: Reward = 500.0
Episode 74: Reward = 500.0
Episode 75: Reward = 500.0
Episode 76: Reward = 500.0
Episode 77: Reward = 500.0
Episode 78: Reward = 500.0
Episode 79: Reward = 500.0
Episode 80: Reward = 500.0
Episode 81: Reward = 500.0
Episode 82: Reward = 500.0
Episode 83: Reward = 500.0
Episode 84: Reward = 500.0
Episode 85: Reward = 500.0
Episode 86: Reward = 500.0
Episode 87: Reward = 500.0
Episode 88: Reward = 500.0
Episode 89: Reward = 500.0
Episode 90: Reward = 500.0
Episode 91: Reward = 500.0
Episode 92: Reward = 500.0
Episode 93: Reward = 500.0
Episode 94: Reward = 500.0
Episode 95: Reward = 500.0
Episode 96: Reward = 500.0
Episode 97: Reward = 500.0
Episode 98: Reward = 500.0
Episode 99: Reward = 500.0
Episode 100: Reward = 500.0
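%% Cell type:markdown id: tags:

The evaluation above always picks the highest-probability action (greedy argmax). Since a REINFORCE policy is stochastic, it can also be evaluated by sampling actions from the predicted distribution. The cell below is a minimal sketch of such a sampled rollout, not part of the original notebook: it assumes the same `Policy_Network` model loaded above and the gym CartPole-v1 API used throughout this notebook, and `run_sampled_episode` is a hypothetical helper name.

%% Cell type:code id: tags:

``` python
import gym
import numpy as np
import torch
from torch.distributions import Categorical

# Sketch (assumption): evaluate one episode by sampling actions from the
# policy's categorical distribution instead of taking the argmax.
# Reuses the `model` loaded earlier in this notebook.
def run_sampled_episode(model, max_time_steps=500):
    env = gym.make("CartPole-v1")
    observation, info = env.reset()
    total_reward, done, time_steps = 0.0, False, 0
    while not done and time_steps < max_time_steps:
        obs_tensor = torch.tensor(np.array(observation), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            action_probs = model(obs_tensor)
        # Sample an action index from the probabilities produced by the softmax head
        action = Categorical(probs=action_probs).sample().item()
        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done = terminated or truncated
        time_steps += 1
    env.close()
    return total_reward

print("Sampled-episode reward:", run_sampled_episode(model))
```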
%% Cell type:code id: tags:
``` python
```
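%% Cell type:markdown id: tags:

Every evaluated episode reaches a reward of 500, which is the CartPole-v1 time-limit cap, so the per-episode plot is flat. A short summary of the reward statistics can be more informative. The sketch below is an illustration, not part of the original notebook: `summarize_rewards` is a hypothetical helper, and using it with `evaluate_model` would require that function to return its `rewards` list (e.g. by adding `return rewards` at the end).

%% Cell type:code id: tags:

``` python
import numpy as np

# Sketch (assumption): report basic statistics for a list of episode rewards.
def summarize_rewards(rewards):
    rewards = np.asarray(rewards, dtype=np.float64)
    print(f"Episodes evaluated : {len(rewards)}")
    print(f"Mean reward        : {rewards.mean():.1f}")
    print(f"Std of rewards     : {rewards.std():.1f}")
    print(f"Min / max reward   : {rewards.min():.1f} / {rewards.max():.1f}")
    # Fraction of episodes hitting the CartPole-v1 time-limit cap of 500 steps
    print(f"Capped episodes    : {(rewards >= 500).mean() * 100:.0f}%")

# Example usage, assuming evaluate_model is modified to return its rewards list:
# rewards = evaluate_model(model, render=False)
# summarize_rewards(rewards)
```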