Commit c571cd65 authored by Benyahia Mohammed Oussama

Replace evaluate_reinforce_cartpole.ipynb

parent 0cde9152
%% Cell type:code id: tags:

``` python
import gym
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from gym.wrappers import RecordVideo
import os

# Define the model architecture
class Policy_Network(nn.Module):
    def __init__(self, num_observations, num_actions):
        super(Policy_Network, self).__init__()
        self.fc1 = nn.Linear(num_observations, 128)
        self.fc2 = nn.Linear(128, num_actions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return self.softmax(x)

# Load the model
def load_model(model_path, num_observations, num_actions):
    model = Policy_Network(num_observations, num_actions)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    return model

# Record a single episode and save it as a video in the specified path
def record_episode(model, video_folder=r"C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL", max_time_steps=500):
    os.makedirs(video_folder, exist_ok=True)  # Ensure the folder exists
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, video_folder=video_folder, episode_trigger=lambda x: True)
    observation, info = env.reset()
    done = False
    time_steps = 0
    while not done and time_steps < max_time_steps:
        obs_tensor = torch.tensor(np.array(observation), dtype=torch.float32).unsqueeze(0)
        action_probs = model(obs_tensor)
        action = np.argmax(action_probs.detach().numpy())  # Choose the action with max probability
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        time_steps += 1
    env.close()
    print(f"Video saved in: {video_folder}")

# Evaluate the model on multiple episodes
def evaluate_model(model, num_episodes=100, render=False):
    max_time_steps = 500
    env = gym.make("CartPole-v1")
    rewards = []
    for episode in range(num_episodes):
        observation, info = env.reset()
        done = False
        total_reward = 0
        time_steps = 0
        while not done and time_steps < max_time_steps:
            obs_tensor = torch.tensor(np.array(observation), dtype=torch.float32).unsqueeze(0)
            action_probs = model(obs_tensor)
            action = np.argmax(action_probs.detach().numpy())  # Choose the action with max probability
            observation, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            time_steps += 1
            done = terminated or truncated
            if render:
                env.render()
        rewards.append(total_reward)  # Store the total reward for the episode
        print(f"Episode {episode+1}: Reward = {total_reward}")
    env.close()

    # Plot the rewards across episodes
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_episodes + 1), rewards, marker="o", linestyle="-", color="b", label="Episode Reward")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Rewards Per Episode")
    plt.legend()
    plt.grid()
    plt.show()

if __name__ == "__main__":
    # Create the environment
    env = gym.make("CartPole-v1")
    num_observations = env.observation_space.shape[0]
    num_actions = env.action_space.n

    # Load the trained model
    model = load_model("reinforce_cartpole.pth", num_observations, num_actions)

    # Record a single episode and save it as a video
    record_episode(model)

    # Evaluate the model without rendering
    evaluate_model(model, render=False)
```
%% Output
c:\PYTHON\lib\site-packages\gym\wrappers\record_video.py:75: UserWarning: WARN: Overwriting existing videos at C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)
logger.warn(
MoviePy - Building video C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4.
MoviePy - Writing video C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4
MoviePy - Done !
MoviePy - video ready C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL\rl-video-episode-0.mp4
Video saved in: C:\Users\BYCInfo\Desktop\M2-2.2\machine learning\RL
Episode 1: Reward = 500.0
Episode 2: Reward = 500.0
Episode 3: Reward = 500.0
Episode 4: Reward = 500.0
Episode 5: Reward = 500.0
Episode 6: Reward = 500.0
Episode 7: Reward = 500.0
Episode 8: Reward = 500.0
Episode 9: Reward = 500.0
Episode 10: Reward = 500.0
Episode 11: Reward = 500.0
Episode 12: Reward = 500.0
Episode 13: Reward = 500.0
Episode 14: Reward = 500.0
Episode 15: Reward = 500.0
Episode 16: Reward = 500.0
Episode 17: Reward = 500.0
Episode 18: Reward = 500.0
Episode 19: Reward = 500.0
Episode 20: Reward = 500.0
Episode 21: Reward = 500.0
Episode 22: Reward = 500.0
Episode 23: Reward = 500.0
Episode 24: Reward = 500.0
Episode 25: Reward = 500.0
Episode 26: Reward = 500.0
Episode 27: Reward = 500.0
Episode 28: Reward = 500.0
Episode 29: Reward = 500.0
Episode 30: Reward = 500.0
Episode 31: Reward = 500.0
Episode 32: Reward = 500.0
Episode 33: Reward = 500.0
Episode 34: Reward = 500.0
Episode 35: Reward = 500.0
Episode 36: Reward = 500.0
Episode 37: Reward = 500.0
Episode 38: Reward = 500.0
Episode 39: Reward = 500.0
Episode 40: Reward = 500.0
Episode 41: Reward = 500.0
Episode 42: Reward = 500.0
Episode 43: Reward = 500.0
Episode 44: Reward = 500.0
Episode 45: Reward = 500.0
Episode 46: Reward = 500.0
Episode 47: Reward = 500.0
Episode 48: Reward = 500.0
Episode 49: Reward = 500.0
Episode 50: Reward = 500.0
Episode 51: Reward = 500.0
Episode 52: Reward = 500.0
Episode 53: Reward = 500.0
Episode 54: Reward = 500.0
Episode 55: Reward = 500.0
Episode 56: Reward = 500.0
Episode 57: Reward = 500.0
Episode 58: Reward = 500.0
Episode 59: Reward = 500.0
Episode 60: Reward = 500.0
Episode 61: Reward = 500.0
Episode 62: Reward = 500.0
Episode 63: Reward = 500.0
Episode 64: Reward = 500.0
Episode 65: Reward = 500.0
Episode 66: Reward = 500.0
Episode 67: Reward = 500.0
Episode 68: Reward = 500.0
Episode 69: Reward = 500.0
Episode 70: Reward = 500.0
Episode 71: Reward = 500.0
Episode 72: Reward = 500.0
Episode 73: Reward = 500.0
Episode 74: Reward = 500.0
Episode 75: Reward = 500.0
Episode 76: Reward = 500.0
Episode 77: Reward = 500.0
Episode 78: Reward = 500.0
Episode 79: Reward = 500.0
Episode 80: Reward = 500.0
Episode 81: Reward = 500.0
Episode 82: Reward = 500.0
Episode 83: Reward = 500.0
Episode 84: Reward = 500.0
Episode 85: Reward = 500.0
Episode 86: Reward = 500.0
Episode 87: Reward = 500.0
Episode 88: Reward = 500.0
Episode 89: Reward = 500.0
Episode 90: Reward = 500.0
Episode 91: Reward = 500.0
Episode 92: Reward = 500.0
Episode 93: Reward = 500.0
Episode 94: Reward = 500.0
Episode 95: Reward = 500.0
Episode 96: Reward = 500.0
Episode 97: Reward = 500.0
Episode 98: Reward = 500.0
Episode 99: Reward = 500.0
Episode 100: Reward = 500.0
%% Cell type:code id: tags:

``` python
```
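The evaluation above is greedy: at every step the action with the highest probability under the policy is taken, and all 100 episodes reach the 500-step cap, which is CartPole-v1's maximum episode length. As an aside not part of the original notebook, the sketch below shows how the same checkpoint could instead be evaluated stochastically, sampling actions from the policy's distribution via `torch.distributions.Categorical`. The helper name `evaluate_model_stochastic` is hypothetical, and the sketch assumes the `Policy_Network` and `load_model` definitions from the cell above are in scope.

``` python
# Sketch only: stochastic evaluation of the same policy, assuming the
# Policy_Network / load_model definitions from the notebook cell above.
import gym
import torch

def evaluate_model_stochastic(model, num_episodes=10, max_time_steps=500):
    env = gym.make("CartPole-v1")
    rewards = []
    for episode in range(num_episodes):
        observation, info = env.reset()
        done = False
        total_reward = 0.0
        time_steps = 0
        while not done and time_steps < max_time_steps:
            obs_tensor = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action_probs = model(obs_tensor)
            # Sample an action from the policy distribution instead of taking the argmax
            action = torch.distributions.Categorical(probs=action_probs).sample().item()
            observation, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            time_steps += 1
            done = terminated or truncated
        rewards.append(total_reward)
    env.close()
    return rewards

# Example usage (assumes the trained weights file from the cell above exists):
# model = load_model("reinforce_cartpole.pth", 4, 2)
# print(evaluate_model_stochastic(model))
```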