python pytorch

How does the code '-input[range(target.shape[0]),target]' work?


I'm learning PyTorch. While reading the official tutorial, I came across this perplexing code. input is a tensor, and so is target.

def nll(input,target):
    return -input[range(target.shape[0]),target].mean()

pred is:
[image: pred]

target is:
[image: target]

'-input[range(target.shape[0]),target]' is:
[image: 'input[range(target.shape[0]),target]']

The output shows that this is neither subtracting target from input nor merging the two tensors.


Solution

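  • The code input[range(target.shape[0]), target] picks, from each row i of input, the element in the column given by the corresponding element of target, that is, target[i].
    In other words, if out = input[range(target.shape[0]), target], then out[i] = input[i, target[i]].

    This works because both indices are integer sequences of the same length, so they are paired element by element: (0, target[0]), (1, target[1]), and so on. The result is a 1-D tensor with one value per row, not a slice or a difference.

    This is very similar to torch.gather: input.gather(1, target.unsqueeze(1)).squeeze(1) returns the same values.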
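
To make the indexing concrete, here is a minimal sketch on a small hand-made tensor. The values in input and target are made up for illustration and are not the ones from the tutorial; nll is the function from the question.

```python
import torch

def nll(input, target):
    # Same function as in the question: pick one value per row, negate the mean.
    return -input[range(target.shape[0]), target].mean()

# Hypothetical data: 4 rows (samples), 3 columns (classes).
input = torch.tensor([[-0.1, -2.0, -3.0],
                      [-1.5, -0.3, -2.5],
                      [-2.2, -1.1, -0.4],
                      [-0.7, -1.9, -1.2]])
target = torch.tensor([0, 1, 2, 0])

# Row indices 0..3 are paired with the column indices in target,
# so picked[i] == input[i, target[i]].
picked = input[range(target.shape[0]), target]

# The same selection written with torch.gather along dimension 1.
gathered = input.gather(1, target.unsqueeze(1)).squeeze(1)

loss = nll(input, target)

print(picked)    # one element per row of input
print(gathered)  # same values as picked
print(loss)      # negated mean of the picked values
```

Here picked holds input[0, 0], input[1, 1], input[2, 2] and input[3, 0]. The leading minus sign applies to the result of .mean(), because the method call binds more tightly than the unary minus.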