
PyTorch: how to get the index of the first 0 in a mask?


I have a tensor that looks like (1, 1, 1, 1, 1, 1, 1, 1, 0, 0). I want to get the index where the first zero appears. What would be the best way to do this?


Solution

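  • This is not really what argmin is meant for, but it works here: the smallest value in the tensor is 0, and when the minimum occurs more than once, torch.argmin returns the index of its first occurrence: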

    >>> import torch
    >>> torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0]).argmin()
    tensor(8)
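The argmin trick relies on the tensor holding only 0s and 1s. If there is no zero at all, it still returns an index: 0, the position of the first 1. A minimal sketch that avoids both issues: compare with 0 explicitly, then take argmax of the resulting 0/1 tensor, which also returns the first maximal position. The function names and the -1 "no zero found" sentinel are hypothetical choices for illustration, not part of PyTorch. The first helper assumes a 1-D tensor, the second a 2-D batch of masks (one mask per row).

```python
import torch


def first_zero_index(mask: torch.Tensor) -> int:
    """Return the index of the first 0 in a 1-D tensor, or -1 if there is none.

    Hypothetical helper. Comparing with 0 first means it also works for
    bool masks and for tensors whose values are not just 0 and 1.
    """
    is_zero = mask == 0  # bool tensor, True where mask is 0
    if not is_zero.any():
        return -1  # no zero present; plain argmin would wrongly report 0 here
    # argmax returns the first position of the maximum value (1 == first zero)
    return int(is_zero.int().argmax())


def first_zero_index_batched(masks: torch.Tensor) -> torch.Tensor:
    """Row-wise version for a 2-D tensor; rows without a zero get -1 (hypothetical)."""
    is_zero = masks == 0
    idx = is_zero.int().argmax(dim=1)  # first zero per row (0 if the row has none)
    # Replace the misleading 0 with -1 for rows that contain no zero at all
    return torch.where(is_zero.any(dim=1), idx, torch.full_like(idx, -1))


if __name__ == "__main__":
    t = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0])
    print(first_zero_index(t))  # 8
    print(first_zero_index(torch.ones(5)))  # -1
    masks = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1]])
    print(first_zero_index_batched(masks).tolist())  # [2, -1, 0]
```

The batched version stays fully vectorized (no Python loop over rows), which is usually what you want when the mask comes from a padded batch.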