Skip to content
Snippets Groups Projects
Select Git revision
  • fa4e312ef63f69815684d7fd4ecef58af3697c4c
  • main default protected
2 results

evaluate_reinforce_cartpole.py

Blame
  • Forked from Dellandrea Emmanuel / MSO_3_4-TD1
    Source project has a limited visibility.
    evaluate_reinforce_cartpole.py 1.22 KiB
    import gymnasium as gym
    import torch
    from reinforce_cartpole import Policy
    
    def eval_policy(eval_length, policy, env):
        """Run a single evaluation episode and return its total reward.

        Parameters
        ----------
        eval_length : int
            Maximum number of environment steps for the episode.
        policy : callable
            Maps a float32 observation tensor to a tensor of action
            probabilities (e.g. the trained REINFORCE policy network).
        env : gymnasium.Env
            Environment to evaluate in; reset at the start of the call.

        Returns
        -------
        float
            Sum of the rewards collected until termination, truncation,
            or the step cap is reached.
        """
        # Reset the environment and get the initial observation
        observation = env.reset()[0]
        rewards = []

        # Evaluation only: no_grad avoids building autograd graphs per step.
        with torch.no_grad():
            for step in range(eval_length):
                # Sample an action from the policy's categorical distribution.
                action_probs = policy(torch.from_numpy(observation).float())
                action = torch.distributions.Categorical(action_probs).sample()
                observation, reward, terminated, truncated, info = env.step(action.numpy())
                rewards.append(reward)
                # visualize agent behavior
                #env.render()
                if terminated or truncated:
                    break
        return sum(rewards)
    # Create the environment (no render mode: evaluation is headless).
    env = gym.make("CartPole-v1")

    policy = Policy()
    # Load the learned policy weights; weights_only=True avoids executing
    # arbitrary pickled code from the checkpoint file.
    policy.load_state_dict(torch.load('reinforce_cartpole.pth', weights_only=True))
    policy.eval()

    # Cap each episode at the environment's configured step limit.
    eval_length = env.spec.max_episode_steps
    num_evals = 100
    number_of_solves = 0
    # Loop variable renamed from `eval`, which shadowed the builtin.
    for episode in range(num_evals):
        episode_reward = eval_policy(eval_length, policy, env)
        # This is one episode's total reward, not an average across episodes.
        print(f"Episode reward: {episode_reward}")
        # NOTE(review): 195 is the CartPole-v0 solve threshold; CartPole-v1 is
        # conventionally considered solved at 475 — confirm which is intended.
        if episode_reward >= 195:
            number_of_solves += 1

    success_rate = number_of_solves / num_evals
    print(f"Success rate: {success_rate}")

    env.close()