I have a PyTorch tensor A of shape (2000, 1, 360, 3). I'm trying to find all indices where the norm along the last dimension is 0, and replace the values at those positions with the values at the corresponding positions in another tensor B (which has the same shape as A).
Example (A and B both of shape (2, 1, 3, 3)):
A = [[[[0, 0, 0],   # norm == 0
       [1, 2, 1],
       [0, 1, 0]]],
     [[[2, 0, 0],
       [0, 0, 0],   # norm == 0
       [1, 1, 1]]]]
B = [[[[0, 0, 1],
       [1, 1, 1],
       [0, 1, 0]]],
     [[[1, 0, 0],
       [0, 1, 1],
       [2, 1, 1]]]]
Expected result:
new_A = [[[[0, 0, 1],   # <-- replaced
           [1, 2, 1],
           [0, 1, 0]]],
         [[[2, 0, 0],
           [0, 1, 1],   # <-- replaced
           [1, 1, 1]]]]
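To make the "find all indices" part concrete, here is a minimal sketch that builds the example A as a float tensor (an assumption: the question does not state the dtype, and norms need a floating-point tensor) and lists the positions of the zero-norm vectors with torch.nonzero.

```python
import torch

# Example tensor A from the question, stored as float so that .norm() is defined.
A = torch.tensor([[[[0., 0., 0.],
                    [1., 2., 1.],
                    [0., 1., 0.]]],
                  [[[2., 0., 0.],
                    [0., 0., 0.],
                    [1., 1., 1.]]]])

# Norm of each length-3 vector in the last dimension: shape (2, 1, 3).
norms = A.norm(dim=-1)
zero_rows = norms == 0

# One row per zero-norm vector, giving its index over the first three dims.
idx = zero_rows.nonzero()
print(idx)
```

Each row of idx is an index (i, j, k) such that A[i, j, k] is an all-zero vector; here that is (0, 0, 0) and (1, 0, 1).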
You can build a mask from the norm of A along the last dimension and use torch.where
to assemble the desired tensor. Note that A must have a floating-point dtype,
since norm is not defined for integer tensors:
> torch.where(A.norm(dim=-1, keepdim=True).bool(), A, B)
tensor([[[[0., 0., 1.],
          [1., 2., 1.],
          [0., 1., 0.]]],
        [[[2., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]]])
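A self-contained sketch of the same approach, checked against the expected result from the question. It also shows two variants that are not part of the answer above, so treat them as alternatives rather than the answer's method: an `.any()`-based mask (a vector has norm 0 exactly when every entry is 0, so this also works for integer tensors) and an in-place replacement via boolean indexing. The same code applies unchanged to a tensor of shape (2000, 1, 360, 3).

```python
import torch

A = torch.tensor([[[[0., 0., 0.], [1., 2., 1.], [0., 1., 0.]]],
                  [[[2., 0., 0.], [0., 0., 0.], [1., 1., 1.]]]])
B = torch.tensor([[[[0., 0., 1.], [1., 1., 1.], [0., 1., 0.]]],
                  [[[1., 0., 0.], [0., 1., 1.], [2., 1., 1.]]]])
expected = torch.tensor([[[[0., 0., 1.], [1., 2., 1.], [0., 1., 0.]]],
                         [[[2., 0., 0.], [0., 1., 1.], [1., 1., 1.]]]])

# keepdim=True leaves a trailing size-1 axis, so the mask of shape (2, 1, 3, 1)
# broadcasts against A and B of shape (2, 1, 3, 3) inside torch.where.
keep = A.norm(dim=-1, keepdim=True) != 0
new_A = torch.where(keep, A, B)

# Alternative mask without norms: keep a vector if any of its entries is nonzero.
# This avoids the floating-point requirement, so it also works on integer tensors.
A_int, B_int = A.long(), B.long()
keep_int = (A_int != 0).any(dim=-1, keepdim=True)
new_A_int = torch.where(keep_int, A_int, B_int)

# In-place alternative: a boolean mask over the first three dims selects whole
# length-3 vectors, so A_inplace[zero] and B[zero] both have shape (num_zero, 3).
zero = A.norm(dim=-1) == 0
A_inplace = A.clone()
A_inplace[zero] = B[zero]
```

torch.where allocates a new tensor, while the boolean-indexing version overwrites only the selected vectors, which may be preferable if you want to modify A itself.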