Tags: python, pytorch, reinforcement-learning

What is the purpose of [np.arange(0, self.batch_size), action] after the neural network?


I followed a PyTorch tutorial on reinforcement learning (TRAIN A MARIO-PLAYING RL AGENT), but I am confused about the following code:

current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action] # Q_online(s,a)

What is the purpose of [np.arange(0, self.batch_size), action] after the neural network? (I know that TD_estimate takes in state and action; I am just confused about this on the programming side.) What is this usage of putting a list after self.net(...)?

More related code referenced from the tutorial:

import copy
from torch import nn


class MarioNet(nn.Module):
    """Mini CNN: three conv + ReLU layers, then flatten and two linear layers."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim

        if h != 84:
            raise ValueError(f"Expecting input height: 84, got: {h}")
        if w != 84:
            raise ValueError(f"Expecting input width: 84, got: {w}")

        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),  # one Q-value per action
        )

        # The target network starts as an exact copy of the online network.
        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        # Route the input through whichever copy of the network is requested.
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)

And self.net is defined as:

self.net = MarioNet(self.state_dim, self.action_dim).float()

Thanks for any help!


Solution

  • Essentially, what happens here is that the output of the net is being indexed to pull out the desired entries of the Q table.

    self.net(state, model="online") returns a tensor of shape (batch_size, num_actions): one row of Q-values per state in the batch. The (somewhat confusing) index [np.arange(0, self.batch_size), action] is advanced ("fancy") indexing: it supplies one index array per axis, and the two arrays are paired up element-wise. For each row i (axis 0), it picks the single column action[i] (axis 1), so the result is a 1-D tensor of length batch_size holding Q_online(s_i, a_i) for every transition in the batch.

    Note that this is not the same as the slice [:, action]. When action is an array of per-sample actions, [:, action] gathers every listed column for every row and yields a (batch_size, batch_size) tensor; the two forms only coincide when action is a single scalar. The pairwise form above (or the equivalent torch.gather call) is what selects exactly one action's Q-value per row, as the sketch below shows.
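To make this concrete, here is a minimal, self-contained sketch (not from the tutorial; the shapes and values are invented) of how the pairwise indexing behaves on a toy batch of Q-values. torch.arange is used in place of np.arange, which works the same way as an index here:

import torch

batch_size, num_actions = 4, 3

# Stand-in for self.net(state, model="online"): one row of Q-values per state.
q_values = torch.arange(batch_size * num_actions, dtype=torch.float32)
q_values = q_values.reshape(batch_size, num_actions)
# tensor([[ 0.,  1.,  2.],
#         [ 3.,  4.,  5.],
#         [ 6.,  7.,  8.],
#         [ 9., 10., 11.]])

# One chosen action index per state in the batch.
action = torch.tensor([2, 0, 1, 2])

# Pairwise (advanced) indexing: row i is paired with column action[i].
current_q = q_values[torch.arange(batch_size), action]
print(current_q)  # tensor([ 2.,  3.,  7., 11.])

# An equivalent gather-based form, which some find more readable:
same_q = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
print(torch.equal(current_q, same_q))  # True

# By contrast, [:, action] takes ALL the listed columns for EVERY row,
# producing a (batch_size, batch_size) tensor -- not what we want here.
print(q_values[:, action].shape)  # torch.Size([4, 4])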