Suppose I have two PyTorch Tensor objects of equal shape:
import torch
x = torch.randn(2, 10)
y = torch.randn(2, 10)
Now I have a list of indices (one per row, i.e. of the same length as the first tensor axis) that give a different starting position along the second tensor axis for each row, from which I want to copy values from y into x, i.e.,
idxs = [2, 6]
for i, idx in enumerate(idxs):
    x[i, idx:] = y[i, idx:]
As shown above, I can do this with a for loop, but is there a more efficient way to do it without an explicit loop?
First, create an index tensor over the second dimension of your tensor:
second_dim_indices = torch.arange(x.shape[1])
and turn idxs into a tensor:
idxs = torch.tensor(idxs, dtype=torch.long)
Then compute a boolean mask that is True wherever an element of x must be overwritten:
mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
# for idxs = [2, 6] this gives:
# tensor([[False, False, True, True, True, True, True, True, True, True],
#         [False, False, False, False, False, False, True, True, True, True]])
Note that second_dim_indices and idxs must be unsqueezed (to shapes (1, 10) and (2, 1) here) so that the >= comparison broadcasts to the full (2, 10) shape.
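To make the broadcasting step concrete, here is a small sketch using the shapes from the question that prints the intermediate shapes and counts how many positions per row the mask selects. The names row and col are purely illustrative.

```python
import torch

x = torch.randn(2, 10)
idxs = torch.tensor([2, 6], dtype=torch.long)

second_dim_indices = torch.arange(x.shape[1])  # shape (10,): 0, 1, ..., 9
row = second_dim_indices.unsqueeze(0)          # shape (1, 10)
col = idxs.unsqueeze(1)                        # shape (2, 1)
mask = row >= col                              # broadcasts to shape (2, 10)

print(row.shape, col.shape, mask.shape)
# torch.Size([1, 10]) torch.Size([2, 1]) torch.Size([2, 10])
print(mask.sum(dim=1).tolist())  # positions overwritten per row: [8, 4]
```

Each row i of the mask is True exactly from column idxs[i] onward, which matches the slice x[i, idx:] in the loop.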
Finally, use the mask to update x:
x = y * mask + x * ~mask
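As a possible alternative for the last step, torch.where(mask, y, x) selects elements instead of blending them arithmetically. With y * mask + x * ~mask, an inf in the unselected tensor becomes inf * 0 = NaN and leaks into the result; torch.where avoids that. The sketch below wraps the approach in hypothetical helpers (offset_mask, assign_from_offsets, and an in-place assign_from_offsets_ using boolean-mask assignment, which mirrors the question's loop more closely) and checks them against the explicit loop.

```python
import torch


def offset_mask(x: torch.Tensor, idxs) -> torch.Tensor:
    """Hypothetical helper: True where column index >= idxs[row]."""
    idxs = torch.as_tensor(idxs, dtype=torch.long, device=x.device)
    cols = torch.arange(x.shape[1], device=x.device)
    return cols.unsqueeze(0) >= idxs.unsqueeze(1)


def assign_from_offsets(x: torch.Tensor, y: torch.Tensor, idxs) -> torch.Tensor:
    """Out-of-place: new tensor taking y where the mask is True, else x."""
    return torch.where(offset_mask(x, idxs), y, x)


def assign_from_offsets_(x: torch.Tensor, y: torch.Tensor, idxs) -> torch.Tensor:
    """In-place: overwrite the masked elements of x with those of y."""
    mask = offset_mask(x, idxs)
    x[mask] = y[mask]
    return x


torch.manual_seed(0)
x = torch.randn(2, 10)
y = torch.randn(2, 10)
idxs = [2, 6]

# Reference result computed with the explicit loop from the question.
expected = x.clone()
for i, idx in enumerate(idxs):
    expected[i, idx:] = y[i, idx:]

print(torch.equal(assign_from_offsets(x, y, idxs), expected))          # True
print(torch.equal(assign_from_offsets_(x.clone(), y, idxs), expected))  # True
```

The out-of-place version leaves x untouched, while the in-place version modifies the tensor passed to it, as the original loop does.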