dictionary, lambda, reinforcement-learning

Python, square brackets after lambda call and dictionary comprehension


I am working through Miguel Morales's Deep Reinforcement Learning book and I have come across some syntax that I am unfamiliar with. I have looked at tutorials on dictionary comprehensions and lambda functions, but I have yet to find out what the square brackets at the end of the lambda do. Can anyone help?

import numpy as np

def policy_improvement(value, mdp, gamma=1.0):
    ''' Performs improvement of a given policy by evaluating the actions for each state and choosing greedily. '''

    Q = np.zeros((len(mdp), len(mdp[0])), dtype=np.float64)

    for state in range(len(mdp)):
        for action in range(len(mdp[state])):
            for transition_prob, state_prime, reward, done in mdp[state][action]:
                # Update the Q value for each action.
                Q[state][action] += transition_prob * (reward + gamma * value[state_prime] * (not done))

    new_policy_pie = lambda state: {state:action for state, action in enumerate(np.argmax(Q, axis=1))}[state]

    return new_policy_pie

Solution

  • The lambda is made of two parts:

    # 1. Create a map of state to action
    d = {state:action for state, action in enumerate(np.argmax(Q, axis=1))}
    
    # 2. Return the value for argument `state`
    d[state]
    

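    So the [state] bit is part of the lambda body: it is a subscript that looks up the key state in the dictionary that was just built. Note that the dictionary is rebuilt on every call. The intent might have been to write a fail-safe lambda in case the state does not exist, but for that they should use .get(state), which returns None for a missing key, whereas [] raises a KeyError.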

    Since the dictionary only maps each row index of Q to that row's argmax, the lambda in your code could be replaced with:

    lambda state: np.argmax(Q, axis=1)[state]
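
To check the equivalence claimed above, here is a minimal sketch. The Q table is a made-up 3-state, 2-action example (not from the book), and the names pi_dict, pi_array and lookup are hypothetical, chosen only for illustration. It compares the lambda from the question with the simplified one, then shows where the two forms differ for states outside the valid range.

```python
import numpy as np

# Hypothetical Q table (3 states, 2 actions), invented for this example.
Q = np.array([[0.1, 0.9],
              [0.7, 0.2],
              [0.4, 0.6]])

# Form from the question: build a {state: action} dict on every call,
# then subscript it with the lambda's argument.
pi_dict = lambda state: {state: action for state, action in enumerate(np.argmax(Q, axis=1))}[state]

# Simplified form from the answer: subscript the argmax array directly.
pi_array = lambda state: np.argmax(Q, axis=1)[state]

# Both agree for every valid state index.
for s in range(len(Q)):
    assert pi_dict(s) == pi_array(s)

# The same dictionary, built once, to show the difference between [] and .get().
lookup = {state: action for state, action in enumerate(np.argmax(Q, axis=1))}
print(lookup.get(5))  # None: .get() does not raise for a missing key
# lookup[5] would raise KeyError, and pi_array(5) would raise IndexError.
# pi_array(-1) returns the last state's action (negative indexing),
# whereas pi_dict(-1) raises KeyError because -1 is not a key.
```

So the two forms behave the same for states 0 to len(Q) - 1 and differ only in how they fail (or, for negative indices, do not fail) on out-of-range input.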