import gymnasium as gym
import torch
from reinforce_cartpole import Policy
# Roll out one CartPole episode with the trained policy and display it on screen.

# Upper bound on the number of environment steps taken in this rollout.
MAX_STEPS = 200

# Create the environment. With render_mode="human" the environment draws a
# frame itself on every reset()/step(), so no explicit env.render() is needed
# (an extra call would just render each frame a second time).
env = gym.make("CartPole-v1", render_mode="human")

try:
    # Reset the environment and get the initial observation (info dict unused)
    observation, _ = env.reset()

    # Load the learned policy weights and switch to evaluation mode
    policy = Policy()
    policy.load_state_dict(torch.load('policy.pth', weights_only=True))
    policy.eval()

    # Pure inference: no gradients are needed, so skip autograd bookkeeping
    with torch.no_grad():
        for _ in range(MAX_STEPS):
            print(observation)
            # Convert the observation once and reuse it for logging and inference
            state = torch.from_numpy(observation).float()
            print(state)

            # Sample an action from the categorical distribution given by the policy
            action_probs = policy(state)
            action = torch.distributions.Categorical(action_probs).sample()

            # Apply the action to the environment. Returns the next observation,
            # the reward, the terminated/truncated end-of-episode flags, and an
            # additional info dictionary. `.item()` hands the environment a
            # plain Python int rather than a 0-d array.
            observation, reward, terminated, truncated, info = env.step(action.item())

            done = terminated or truncated
            print(done)
            if done:
                # Episode ended before reaching MAX_STEPS
                break
finally:
    # Always release the render window, even if loading or stepping failed
    env.close()