
How to get the index of a specific value in a PyTorch tensor


With python lists, we can do:

a = [1, 2, 3]
assert a.index(2) == 1

Is there a direct equivalent of .index() for a PyTorch tensor?


Solution

  • I think there is no direct PyTorch equivalent of list.index(). However, you can get a similar result by comparing the tensor with the value (tensor == value), which gives a boolean mask, and then calling nonzero() on that mask to get the indices of the matching elements. For example:

    import torch
    t = torch.tensor([1, 2, 3])
    # Compare element-wise, then take the indices where the mask is True.
    print((t == 2).nonzero(as_tuple=True)[0])
    

    This prints

    tensor([1])

    which is a 1-D tensor holding the index of every element equal to 2 (here only index 1).
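
If you want behaviour closer to list.index() (a single Python int for the first match, and an error when the value is missing), the comparison and nonzero() call can be wrapped in a small function. This is only a sketch: tensor_index is a hypothetical helper name, not part of PyTorch, and it assumes a 1-D tensor.

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper mimicking list.index() for a 1-D tensor."""
    # Indices of all elements equal to `value`, as a 1-D tensor.
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # list.index() raises ValueError when the value is absent.
        raise ValueError(f"{value} is not in tensor")
    # nonzero() returns indices in ascending order, so the first is the smallest.
    return int(matches[0])


t = torch.tensor([1, 2, 3, 2])
first = tensor_index(t, 2)                                 # first match only
all_matches = (t == 2).nonzero(as_tuple=True)[0].tolist()  # every match
```

Note that, unlike list.index(), this scans the whole tensor before returning, and for floating-point tensors an exact == comparison may miss values that differ only by rounding error.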