I am working through Miguel Morales's Deep Reinforcement Learning book and have come across some syntax that I am unfamiliar with. I have looked at tutorials on dictionary comprehensions and lambda functions, but I have yet to find out what the square brackets at the end of the lambda do. Can anyone help?
def policy_improvement(value, mdp, gamma=1.0):
    ''' Performs improvement of a given policy by evaluating the actions for each state and choosing greedily. '''
    Q = np.zeros((len(mdp), len(mdp[0])), dtype=np.float64)
    for state in range(len(mdp)):
        for action in range(len(mdp[state])):
            for transition_prob, state_prime, reward, done in mdp[state][action]:
                Q[state][action] += transition_prob * (reward + gamma * value[state_prime] * (not done))  # Update the Q value for each action.
    new_policy_pie = lambda state: {state:action for state, action in enumerate(np.argmax(Q, axis=1))}[state]
    return new_policy_pie
The lambda is made of two parts:
# 1. Create a map of state to action
d = {state:action for state, action in enumerate(np.argmax(Q, axis=1))}
# 2. Return the value for argument `state`
d[state]
So the [state] bit is part of the lambda body: it is the subscript that looks up the key state in the dictionary the comprehension has just built.
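To make the two steps concrete, here is a minimal sketch you can run. The 3-state, 2-action Q table is made up for illustration and does not come from the book; the names greedy_actions and policy are hypothetical too.

```python
import numpy as np

# Hypothetical Q table (3 states x 2 actions), invented for this example.
Q = np.array([[0.1, 0.9],
              [0.7, 0.2],
              [0.4, 0.6]])

# Step 1: the dict comprehension maps each state index to its greedy action.
greedy_actions = {state: action for state, action in enumerate(np.argmax(Q, axis=1))}

# Step 2: the trailing [state] is an ordinary dictionary lookup.
action_for_state_1 = greedy_actions[1]

# The one-liner from the question does both steps in a single expression.
policy = lambda state: {s: a for s, a in enumerate(np.argmax(Q, axis=1))}[state]

print(greedy_actions)
print(action_for_state_1, policy(1))
```

Note that the lambda rebuilds the whole dictionary every time it is called, only to read a single entry from it.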
The intent might have been to write a fail-safe lambda in case the state does not exist, but then they should use .get(state) instead of [state].
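The difference between the two lookups can be seen in a small sketch. The dictionary contents and the missing key 5 are hypothetical, chosen only to trigger the failure case.

```python
# Hypothetical state -> action map, standing in for the comprehension's result.
greedy_actions = {0: 1, 1: 0, 2: 1}

# Subscripting with a key that is not present raises KeyError.
try:
    greedy_actions[5]
    raised = False
except KeyError:
    raised = True

# .get returns None for a missing key, or the default you pass as second argument.
missing = greedy_actions.get(5)
fallback = greedy_actions.get(5, 0)

print(raised, missing, fallback)
```

Since the lambda in the question uses [state], it is not fail-safe: an unknown state raises KeyError.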
So in the end, the lambda in your code could be replaced with:
lambda state: np.argmax(Q, axis=1)[state]
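As a quick check, the sketch below compares the two forms on a made-up Q table (not from the book). It also shows a variant, cached_policy, that computes the argmax once instead of on every call; that variant is my own suggestion and assumes Q is not modified afterwards.

```python
import numpy as np

# Hypothetical Q table (3 states x 2 actions), invented for this example.
Q = np.array([[0.1, 0.9],
              [0.7, 0.2],
              [0.4, 0.6]])

# Original form: build a dict on every call, then look up the state.
dict_policy = lambda state: {s: a for s, a in enumerate(np.argmax(Q, axis=1))}[state]

# Simplified form: index the argmax array directly.
array_policy = lambda state: np.argmax(Q, axis=1)[state]

# Variant that computes the greedy actions only once.
greedy = np.argmax(Q, axis=1)
cached_policy = lambda state: greedy[state]

# All three agree on every valid state index.
same = all(dict_policy(s) == array_policy(s) == cached_policy(s)
           for s in range(len(Q)))
print(same)
```

The forms agree for valid states but differ at the edges: an out-of-range state raises KeyError with the dictionary and IndexError with the array, and a negative state such as -1 is accepted by the array (it indexes from the end) but not by the dictionary.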