Tags: python-3.x, reinforcement-learning, openai-gym

OpenAI Gym Taxi Q-learning: how can I fix this IndexError?


I am learning reinforcement learning and want to implement tabular Q-learning to solve the OpenAI Gym Taxi environment. I found the code below online, but when I try to run it I get an error. Here is the code:

import numpy as np
import gym
import random

def main():

    # create Taxi environment
    env = gym.make('Taxi-v3')

    # initialize q-table
    state_size = env.observation_space.n
    action_size = env.action_space.n
    qtable = np.zeros((state_size, action_size))

    # hyperparameters
    learning_rate = 0.9
    discount_rate = 0.8
    epsilon = 1.0
    decay_rate = 0.005

    # training variables
    num_episodes = 1000
    max_steps = 99 # per episode

    # training
    for episode in range(num_episodes):

        # reset the environment
        state = env.reset()
        done = False

        for s in range(max_steps):

            # exploration-exploitation tradeoff
            if random.uniform(0,1) < epsilon:
                # explore
                action = env.action_space.sample()
            else:
                # exploit
                action = np.argmax(qtable[state,:])

            # take action and observe reward
            new_state, reward, done, trunc, info = env.step(action)

            # Q-learning algorithm
            qtable[state,action] = qtable[state,action] + learning_rate * int(reward + discount_rate * np.max(qtable[new_state,:]) - qtable[state,action])

            # Update to our new state
            state = new_state

            # if done, finish episode
            if done == True:
                break

        # Decrease epsilon
        epsilon = np.exp(-decay_rate*episode)

    print(f"Training completed over {num_episodes} episodes")
    input("Press Enter to watch trained agent...")

    # watch trained agent
    state = env.reset()
    done = False
    rewards = 0

    for s in range(max_steps):

        print(f"TRAINED AGENT")
        print("Step {}".format(s+1))

        action = np.argmax(qtable[state,:])
        new_state, reward, done, trunc, info = env.step(action)
        rewards += reward
        env.render()
        print(f"score: {rewards}")
        state = new_state

        if done == True:
            break

    env.close()

if __name__ == "__main__":
    main()

When I try to run the above code I get the following error message:

Traceback (most recent call last):

  File "/tmp/ipykernel_2838/974516385.py", line 84, in <module>
    main()

  File "/tmp/ipykernel_2838/974516385.py", line 46, in main
    qtable[state,action] = qtable[state,action] + learning_rate * int(reward + discount_rate * np.max(qtable[new_state,:]) - qtable[state,action])

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
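The same IndexError can be reproduced with NumPy alone. This is a minimal sketch, assuming Gym 0.26 or later, where env.reset() returns an (observation, info) tuple; the state 328 and the info dict below are made-up stand-ins, not values Taxi is guaranteed to return. Because the first episode starts with epsilon = 1.0, the explore branch runs first, so the first line that indexes the Q-table with the tuple is the update line, which matches the traceback.

```python
import numpy as np

qtable = np.zeros((500, 6))  # Taxi-v3: 500 states, 6 actions
action = 2

# Under Gym 0.26+, "state = env.reset()" stores an (observation, info) tuple.
state = (328, {"prob": 1.0})  # hypothetical observation and info dict

caught = None
try:
    qtable[state, action]  # NumPy cannot turn (328, {...}) into an index
except IndexError as exc:
    caught = exc
    print("IndexError:", exc)

# Unpacking the tuple leaves a plain int, which is a valid index.
obs, info = state
print(qtable[obs, action])  # 0.0
```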

An answer to my question suggested that indexing qtable[state, action] requires state, action and new_state to be integers, so I edited the following parts of the code.

Action Code:

# Exploration-exploitation tradeoff
if random.uniform(0, 1) < epsilon:
    # Explore
    action = int(env.action_space.sample())
else:
    # Exploit
    action = int(np.argmax(qtable[int(state), :]))

Q-table update (new_state) Code:

qtable[state, action] = qtable[state, action] + learning_rate * int(reward + discount_rate * np.max(qtable[int(new_state), :])) - qtable[state, action]
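Separately from the IndexError, this edit changes the update rule itself: the closing parenthesis now comes before - qtable[state, action], so the old value is subtracted outside the learning_rate * (...) term. The int() around the target (present in the original line as well) also truncates fractional values; int(0.2) is 0. Tabular Q-learning updates Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)). Below is a minimal sketch of that update; q_update is a hypothetical helper name and the Q-value for state 7 is made up.

```python
import numpy as np

def q_update(qtable, state, action, reward, new_state, learning_rate, discount_rate):
    """One tabular Q-learning step; state and new_state are integer state IDs."""
    # TD target: reward plus the discounted best Q-value of the next state
    target = reward + discount_rate * np.max(qtable[new_state, :])
    # Move Q(s, a) a fraction learning_rate of the way toward the target.
    # No int() here: Q-values are floats, and truncation would discard learning signal.
    qtable[state, action] += learning_rate * (target - qtable[state, action])

qtable = np.zeros((500, 6))
qtable[7, 1] = 1.5  # made-up value for the next state
q_update(qtable, state=3, action=2, reward=-1, new_state=7,
         learning_rate=0.9, discount_rate=0.8)
print(qtable[3, 2])  # about 0.18 = 0.9 * (-1 + 0.8 * 1.5 - 0.0)
```

With the question's int() in place, the same step would compute 0.9 * int(0.2), which is 0.0, so the agent would learn nothing from this transition.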

State Code:

state = int(new_state)

When I run only the first part of the code (training), I get no errors. When I run the training part together with the second part ("watch trained agent"), I get the same error:

Traceback (most recent call last):

  File "/tmp/ipykernel_9211/3085956694.py", line 73, in <module>
    main()

  File "/tmp/ipykernel_9211/3085956694.py", line 40, in main
    qtable[state, action] = qtable[state, action] + learning_rate * int(reward + discount_rate * np.max(qtable[int(new_state), :])) - qtable[state, action]

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

Why am I getting this error, and what do I need to change to fix it?


Solution

  • The problem is this line:

    state = env.reset()
    

    In Gym 0.26 and later, env.reset() returns a tuple (observation, info) instead of the observation alone. The Q-table has to be indexed with the observation, which for Taxi-v3 is an integer state ID, so we need only the first element of the tuple. To get it, change the line to

    state, _ = env.reset()
    

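    Now state holds the integer starting state (chosen at random by Taxi-v3, not always zero), so qtable[state, :] and qtable[state, action] are valid indices. Make the same change to the env.reset() call before the "watch trained agent" loop. Because Taxi-v3 already returns integer states from step(), the int() casts from the question should then be unnecessary.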
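Putting the fix together, here is a minimal sketch of the training and "watch" loops under the Gym 0.26+ API, where step() returns (observation, reward, terminated, truncated, info). Besides unpacking both env.reset() calls, it drops the int() around the TD error, ends an episode on either terminated or truncated, and (a small change from the question's code) does not bootstrap from a terminal state. The names train_q_table and greedy_rollout and the LineWorld environment are hypothetical; LineWorld is a tiny stand-in that only exists so the sketch runs without Gym installed.

```python
import numpy as np
from types import SimpleNamespace


def train_q_table(env, num_episodes=1000, max_steps=99, learning_rate=0.9,
                  discount_rate=0.8, decay_rate=0.005, seed=0):
    """Tabular Q-learning for a discrete env using the Gym 0.26+ reset/step API."""
    rng = np.random.default_rng(seed)
    n_actions = env.action_space.n
    qtable = np.zeros((env.observation_space.n, n_actions))
    epsilon = 1.0
    for episode in range(num_episodes):
        state, _ = env.reset()  # the fix: keep the observation, drop info
        for _ in range(max_steps):
            if rng.uniform() < epsilon:
                action = int(rng.integers(n_actions))      # explore
            else:
                action = int(np.argmax(qtable[state, :]))  # exploit
            new_state, reward, terminated, truncated, _ = env.step(action)
            # No bootstrapping past a terminal state, and no int() truncation
            best_next = 0.0 if terminated else np.max(qtable[new_state, :])
            target = reward + discount_rate * best_next
            qtable[state, action] += learning_rate * (target - qtable[state, action])
            state = new_state
            if terminated or truncated:
                break
        epsilon = np.exp(-decay_rate * episode)  # same decay schedule as the question
    return qtable


def greedy_rollout(env, qtable, max_steps=99):
    """Watch the trained agent: always take the best action, return total reward."""
    state, _ = env.reset()  # the second reset() needs the same unpacking
    total = 0
    for _ in range(max_steps):
        action = int(np.argmax(qtable[state, :]))
        state, reward, terminated, truncated, _ = env.step(action)
        total += reward
        if terminated or truncated:
            break
    return total


class LineWorld:
    """Hypothetical stand-in env: 5 cells, action 1 moves right, 0 moves left;
    reaching cell 4 ends the episode with reward 10, every other step costs 1."""
    observation_space = SimpleNamespace(n=5)
    action_space = SimpleNamespace(n=2)

    def reset(self):
        self.pos = 0
        return self.pos, {}  # (observation, info), as in Gym 0.26+

    def step(self, action):
        self.pos = min(self.pos + 1, 4) if action == 1 else max(self.pos - 1, 0)
        terminated = self.pos == 4
        reward = 10 if terminated else -1
        return self.pos, reward, terminated, False, {}


env = LineWorld()
qtable = train_q_table(env, num_episodes=300)
print(greedy_rollout(env, qtable))  # 7: three -1 steps, then +10 at the goal
```

With Gym installed, the same functions should accept gym.make('Taxi-v3'), whose Discrete observation and action spaces expose .n. To print the board while watching the agent, Gym 0.26+ takes the render mode in gym.make, for example gym.make('Taxi-v3', render_mode='ansi'), after which env.render() returns the frame as a string.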