evaluate_reinforce_cartpole.py
import gymnasium as gym
import torch
from reinforce_cartpole import Policy
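# The Policy class is defined in the training script (reinforce_cartpole.py) and is not
# shown here. The evaluation below only assumes it is a torch.nn.Module whose forward
# pass maps a 4-dimensional CartPole observation to a probability distribution over the
# 2 discrete actions. A minimal sketch of such a module (an assumption for reference,
# not necessarily the actual training architecture) would be:
#
#   class Policy(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.layers = torch.nn.Sequential(
#               torch.nn.Linear(4, 128),
#               torch.nn.ReLU(),
#               torch.nn.Linear(128, 2),
#               torch.nn.Softmax(dim=-1),
#           )
#
#       def forward(self, x):
#           return self.layers(x)
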
def eval_policy(eval_length, policy, env):
    # Reset the environment and get the initial observation
    observation = env.reset()[0]
    rewards = []
    for step in range(eval_length):
        # Sample an action from the policy's output distribution (no gradients needed at eval time)
        with torch.no_grad():
            action_probs = policy(torch.from_numpy(observation).float())
        action = torch.distributions.Categorical(action_probs).sample()
        observation, reward, terminated, truncated, info = env.step(action.item())
        rewards.append(reward)
        # Visualize agent behavior (requires a render_mode to be set in gym.make)
        # env.render()
        if terminated or truncated:
            break
    return sum(rewards)
# Create the environment
env = gym.make("CartPole-v1")
policy = Policy()
# load learned policy
policy.load_state_dict(torch.load('reinforce_cartpole.pth', weights_only=True))
policy.eval()
eval_length = env.spec.max_episode_steps
num_evals = 100
number_of_solves = 0
for episode in range(num_evals):
    episode_reward = eval_policy(eval_length, policy, env)
    print(f"Episode reward: {episode_reward}")
    # Count the episode as solved at the classic 195-reward threshold
    if episode_reward >= 195:
        number_of_solves += 1

success_rate = number_of_solves / num_evals
print(f"Success rate: {success_rate}")
env.close()