
PPO model learns well then predicts only negative actions

I'm using the gymnasium Python package (the maintained fork of OpenAI's gym) with stable-baselines3's PPO to train an agent on a simple grid-based game, similar to gym's GridWorld example. Most actions result in a positive reward; usually only one action results in a negative reward.

During the learning phase, I can see, by printing from the environment's step() function, that the model is doing pretty well: it rarely chooses the actions that would give negative rewards.

When I then test the model by predicting on a new game, it freaks out: it chooses a few good actions and then keeps choosing the one action that gives a negative reward. Once it finds the bad action, it sticks with it until the end.

Is there a bug in the code for testing/using the model to predict?

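from stable_baselines3 import PPO

env = GameEnv()  # custom grid-game environment (definition not shown)
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)  # train first; the timestep budget here is illustrative

obs, info = env.reset()  # gymnasium's reset() returns (observation, info)

for _ in range(50):
    action, _states = model.predict(obs, deterministic=True)
    # gymnasium's step() returns five values; "done" is split into terminated/truncated
    obs, reward, terminated, truncated, info = env.step(int(action))

    if terminated or truncated:
        obs, info = env.reset()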

Sample output from learning:

action, reward = 2, 1
action, reward = 3, 1
action, reward = 2, 5
action, reward = 0, 1
action, reward = 0, 9
action, reward = 1, 1
action, reward = 3, 1
action, reward = 3, -5
action, reward = 2, 1

Sample output from testing:

action, reward = 0, 1
action, reward = 1, 5
action, reward = 2, 1
action, reward = 0, 1
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5
action, reward = 0, -5


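  • The problem is passing deterministic=True to model.predict().

    With deterministic=True, predict() does not sample from the policy's action distribution; for a discrete action space it always returns the most probable action for the given observation. If that action leaves the agent in the same state (for example, moving into a wall), the next observation is identical, so the model picks the same bad action again and never escapes, which explains the fixation on the negative action. Setting the flag to False, so that actions are sampled, resulted in the expected behavior.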

    Changing this line:

        action, _states = model.predict(obs, deterministic=True)

    to this:

        action, _states = model.predict(obs, deterministic=False)

    fixes the problem.
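The effect can be reproduced with a toy sketch that uses neither gymnasium nor stable-baselines3. The corridor environment, its rewards, and the action probabilities below are all hypothetical stand-ins for the grid game and the trained policy. Deterministic selection takes the argmax of the action probabilities (the analogue of predict(deterministic=True) for a discrete action space); stochastic selection samples from them with random.Random.choices.

```python
import random

# Hypothetical 1-D corridor standing in for the grid game: positions 0..4.
# Action 0 = "left", action 1 = "right".
def step(pos, action):
    """Return (new_pos, reward) for the toy corridor."""
    if action == 0:
        if pos == 0:
            return pos, -5  # walked into the wall: state unchanged, negative reward
        return pos - 1, 1
    return min(pos + 1, 4), 1

# Hypothetical action probabilities of an imperfectly trained policy:
# it leans toward "left" everywhere, including at the wall where that is the bad move.
def action_probs(pos):
    return [0.55, 0.45] if pos == 0 else [0.6, 0.4]

def choose(pos, deterministic, rng):
    probs = action_probs(pos)
    if deterministic:
        # Mode of the distribution: always the same action for the same state.
        return max(range(len(probs)), key=lambda a: probs[a])
    # Sample an action in proportion to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

def rollout(deterministic, steps=10, start=2, seed=0):
    """Run one episode from `start` and return the list of rewards."""
    rng = random.Random(seed)
    pos, rewards = start, []
    for _ in range(steps):
        action = choose(pos, deterministic, rng)
        pos, reward = step(pos, action)
        rewards.append(reward)
    return rewards

print(rollout(deterministic=True))  # [1, 1, -5, -5, -5, -5, -5, -5, -5, -5]
print(rollout(deterministic=False, steps=30))
```

The deterministic run mirrors the test log in the question: a couple of positive rewards, then -5 forever, because the wall leaves the state unchanged and the argmax keeps choosing "left". The stochastic run still tends to choose "left" at the wall, but it eventually samples "right" and moves on.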