Torch: Update tensor with non-zero elements


Suppose I have:

>>> a = torch.tensor([1, 2, 3, 0, 0, 1])
>>> b = torch.tensor([0, 1, 3, 3, 0, 0])

I want to update b with the elements of a wherever a is non-zero. How can I do that efficiently?

Expected:

>>> b = torch.tensor([1, 2, 3, 3, 0, 1])

Solution

  • To add to the previous answer: for simplicity, you can do it in one line with torch.where, which takes each element from a where a is non-zero and from b otherwise:

    b = torch.where(a != 0, a, b)

    Output:

    tensor([1, 2, 3, 3, 0, 1])
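
Note that torch.where returns a new tensor rather than modifying b. If a true in-place update is what you want, boolean-mask assignment is one way to get it. The sketch below shows both side by side; the names `merged` and `mask` are only illustrative and do not appear in the question.

```python
import torch

a = torch.tensor([1, 2, 3, 0, 0, 1])
b = torch.tensor([0, 1, 3, 3, 0, 0])

# Out-of-place: torch.where builds a new tensor; b itself is not modified here.
merged = torch.where(a != 0, a, b)

# In-place alternative: overwrite b only at the positions where a is non-zero.
mask = a != 0
b[mask] = a[mask]

print(merged)
print(b)
```

Both approaches assume a and b have the same shape (or shapes that broadcast, in the torch.where case). The masked assignment changes b's existing storage, so any other reference to b sees the update as well.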