
Replace rows with a norm of 0 in a tensor with the corresponding rows of another tensor


I have a PyTorch tensor A of shape (2000, 1, 360, 3). I'm trying to find every position whose vector along the last dimension has a norm of 0, and replace those vectors with the values at the corresponding positions in another tensor B (which has the same shape as A).

Example (A and B both have shape (2, 1, 3, 3)):

import torch

A = torch.tensor([[[[0., 0., 0.],  # norm == 0
                    [1., 2., 1.],
                    [0., 1., 0.]]],
                  [[[2., 0., 0.],
                    [0., 0., 0.],  # norm == 0
                    [1., 1., 1.]]]])
B = torch.tensor([[[[0., 0., 1.],
                    [1., 1., 1.],
                    [0., 1., 0.]]],
                  [[[1., 0., 0.],
                    [0., 1., 1.],
                    [2., 1., 1.]]]])

Expected result:

new_A = [[[[0, 0, 1],   # <-- replaced
           [1, 2, 1],
           [0, 1, 0]]],
         [[[2, 0, 0],
           [0, 1, 1],   # <-- replaced
           [1, 1, 1]]]]

Solution

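  • You can build a mask from the norm of A along the last dimension and use torch.where to assemble the desired tensor. The mask is True where the norm is nonzero (keep A) and False where it is 0 (take B); keepdim=True leaves a trailing dimension of size 1, so the mask broadcasts against A and B: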

    > torch.where(A.norm(dim=-1, keepdim=True).bool(), A, B)
    tensor([[[[0., 0., 1.],
              [1., 2., 1.],
              [0., 1., 0.]]],
    
    
            [[[2., 0., 0.],
              [0., 1., 1.],
              [1., 1., 1.]]]])
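A minimal self-contained sketch of the same approach, assuming a reasonably recent PyTorch. It swaps the norm test for an equivalent check (a vector has norm 0 exactly when all its entries are 0), which skips the square root and also works on integer tensors, whereas Tensor.norm() expects a floating-point input. The variable names `zero_rows`, `mask` and `A_inplace` are illustrative only. It also shows an in-place variant using boolean indexing and how to list the indexes of the zero-norm rows, which the question asked about.

```python
import torch

# Example tensors from the question, stored as floats.
A = torch.tensor([[[[0., 0., 0.],
                    [1., 2., 1.],
                    [0., 1., 0.]]],
                  [[[2., 0., 0.],
                    [0., 0., 0.],
                    [1., 1., 1.]]]])
B = torch.tensor([[[[0., 0., 1.],
                    [1., 1., 1.],
                    [0., 1., 0.]]],
                  [[[1., 0., 0.],
                    [0., 1., 1.],
                    [2., 1., 1.]]]])

# Shape (..., 1): True where the last-dim vector of A is all zeros (norm == 0).
zero_rows = (A == 0).all(dim=-1, keepdim=True)

# Out-of-place: take B where the row is all zeros, keep A elsewhere.
# The (..., 1) mask broadcasts against the (..., 3) tensors.
new_A = torch.where(zero_rows, B, A)

# In-place alternative: a boolean mask of shape A.shape[:-1] selects whole
# rows, so the matching rows of B are copied into a copy of A.
A_inplace = A.clone()
mask = zero_rows.squeeze(-1)
A_inplace[mask] = B[mask]

# Indexes (over A.shape[:-1]) of the rows whose norm is 0, one per line.
idx = mask.nonzero()
print(idx)
```

For the full (2000, 1, 360, 3) tensor the same lines apply unchanged; `torch.where` allocates a new tensor, while the boolean-indexing version overwrites the selected rows of the tensor it is applied to.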