
PyTorch advanced indexing with a list of lists as indices

Here is some python code to reproduce my issue:

import torch

n, m = 9, 4

x = torch.arange(0, n * m).reshape(n, m)
# torch.Size([9, 4])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [24, 25, 26, 27],
#         [28, 29, 30, 31],
#         [32, 33, 34, 35]])

# one sublist per row of x: the column indices to set to -1 in that row
list_of_indices = [
    [],
    [2, 3],
    [1],
    [],
    [],
    [],
    [0, 1, 2, 3],
    [],
    [0, 3],
]

for i, indices in enumerate(list_of_indices):
    x[i, indices] = -1

# tensor([[ 0,  1,  2,  3],
#         [ 4,  5, -1, -1],
#         [ 8, -1, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [-1, -1, -1, -1],
#         [28, 29, 30, 31],
#         [-1, 33, 34, -1]])

I have a list of lists of indices, and I want to set entries of x to a specific value (here -1) using the indices in list_of_indices. Each sublist corresponds to one row of x and contains the column indices to set to -1 in that row. This is easy with a for-loop, but I suspect PyTorch offers a more efficient, vectorized way to do it.

I tried the following:

x[torch.arange(len(list_of_indices)), list_of_indices] = -1

but it resulted in

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [9, 0]

I tried to find people with the same problem, but there are so many questions about indexing tensors that I may have missed it.
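For comparison, the x[torch.arange(...), list_of_indices] attempt fails because advanced indexing pairs the index tensors element by element after broadcasting them, and a ragged list of lists cannot be broadcast against the 9 row indices. A minimal sketch of a vectorized fix, assuming the 9-row list_of_indices from the question, expands the list into two flat tensors of equal length: one holding the row of every target element and one holding its column. torch.repeat_interleave repeats each row number once per entry in that row's sublist.

```python
import torch

n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)

# Same ragged list as in the question: one sublist of column indices per row.
list_of_indices = [[], [2, 3], [1], [], [], [], [0, 1, 2, 3], [], [0, 3]]

# Number of column indices in each row's sublist.
lengths = torch.tensor([len(r) for r in list_of_indices])

# Row number repeated once per column index in that row.
rows = torch.repeat_interleave(torch.arange(len(list_of_indices)), lengths)

# All column indices concatenated in the same row order
# (dtype=torch.long keeps it an integer index even if every sublist is empty).
cols = torch.tensor([j for r in list_of_indices for j in r], dtype=torch.long)

# rows and cols have the same 1-D shape, so advanced indexing pairs them
# element by element: x[rows[k], cols[k]] is set for every k.
x[rows, cols] = -1
print(x)
```

The only Python-level work left is measuring and concatenating the sublists; the assignment itself is a single indexing call, and x stays two-dimensional.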


  • To do this in place, you can flatten x and ravel the indices so that they index the flattened tensor. First gather the raveled indices, then index x.flatten() (for a contiguous x this returns a view, not a copy, so the assignment writes into x).

    # the raveled offset of element (i, j) in a row-major n x m tensor is i*m + j
    indices = torch.tensor([i * m + j for i, r in enumerate(list_of_indices) for j in r])
    indices
    > tensor([ 6,  7,  9, 24, 25, 26, 27, 32, 35])
    x.flatten()[indices] = -1
    x
    > tensor([[ 0,  1,  2,  3],
              [ 4,  5, -1, -1],
              [ 8, -1, 10, 11],
              [12, 13, 14, 15],
              [16, 17, 18, 19],
              [20, 21, 22, 23],
              [-1, -1, -1, -1],
              [28, 29, 30, 31],
              [-1, 33, 34, -1]])
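The in-place claim can be checked directly. A minimal sketch, reusing the question's x and list_of_indices: it confirms that x.flatten() shares storage with the contiguous x, and then shows Tensor.put_ (a swapped-in alternative, not part of the answer), which treats the tensor as if it were 1-D and so writes to the same raveled offsets without calling flatten().

```python
import torch

n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
list_of_indices = [[], [2, 3], [1], [], [], [], [0, 1, 2, 3], [], [0, 3]]

# Raveled (row-major) offsets i * m + j, computed as in the answer.
indices = torch.tensor(
    [i * m + j for i, r in enumerate(list_of_indices) for j in r],
    dtype=torch.long,
)

# x is contiguous, so flatten() returns a view sharing x's storage;
# writing through that view therefore modifies x itself.
assert x.is_contiguous()
assert x.flatten().data_ptr() == x.data_ptr()

# Tensor.put_ indexes self as a 1-D tensor; source must hold one value
# per index and have the same dtype as x (int64 here).
x.put_(indices, torch.full_like(indices, -1))
print(x)
```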

    You can also use torch.scatter_, but in that case it is slightly longer to write:
