
Can you make assignments between PyTorch tensors using ragged indices without a for loop?


Suppose I have two PyTorch Tensor objects of equal shape:

import torch

x = torch.randn(2, 10)
y = torch.randn(2, 10)

Now, I have a list of indices (one per row, i.e. the same length as the first tensor axis) giving, for each row, the starting position along the second tensor axis from which I want to assign values from y into x, i.e.,

idxs = [2, 6]
for i, idx in enumerate(idxs):
    x[i, idx:] = y[i, idx:]

As above, I can do this with a for loop, but is there a more efficient way that avoids the explicit loop?


Solution

  • First, create an index tensor along the second dimension of your tensor with

    second_dim_indices = torch.arange(x.shape[1])
    

    and turn idxs into a tensor:

    idxs = torch.tensor(idxs)
    

    Then compute a boolean mask that is True wherever the corresponding element of x should be overwritten:

    mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
    # in your case:
    #  tensor([[False, False,  True,  True,  True,  True,  True,  True,  True,  True],
    #          [False, False, False, False, False, False,  True,  True,  True,  True]])
    

    Note that we must unsqueeze both second_dim_indices and idxs so that the >= comparison broadcasts to shape (2, 10).

    Finally, use the mask to update x:

    x = y * mask + x * ~mask
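
    Putting the pieces together, here is a minimal end-to-end sketch (using the same shapes and idxs as above) that applies the masked update and checks it against the original for loop. It uses torch.where(mask, y, x), which gives the same result as y * mask + x * ~mask but reads more directly as "take y where the mask is True, otherwise keep x":

    import torch

    x = torch.randn(2, 10)
    y = torch.randn(2, 10)
    idxs = [2, 6]

    # Vectorised update: build the broadcast mask and select between y and x.
    second_dim_indices = torch.arange(x.shape[1])                       # shape (10,)
    idxs_tensor = torch.tensor(idxs)                                    # shape (2,)
    mask = second_dim_indices.unsqueeze(0) >= idxs_tensor.unsqueeze(1)  # shape (2, 10)
    x_vectorised = torch.where(mask, y, x)  # same result as y * mask + x * ~mask

    # Reference result from the original for loop, for comparison.
    x_loop = x.clone()
    for i, idx in enumerate(idxs):
        x_loop[i, idx:] = y[i, idx:]

    assert torch.equal(x_vectorised, x_loop)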