Here is some python code to reproduce my issue:
import torch
n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
print(x.shape)
print(x)
# torch.Size([9, 4])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [24, 25, 26, 27],
#         [28, 29, 30, 31],
#         [32, 33, 34, 35]])
list_of_indices = [
[],
[2, 3],
[1],
[],
[],
[],
[0, 1, 2, 3],
[],
[0, 3],
]
print(list_of_indices)
for i, indices in enumerate(list_of_indices):
    x[i, indices] = -1
print(x)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5, -1, -1],
#         [ 8, -1, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [-1, -1, -1, -1],
#         [28, 29, 30, 31],
#         [-1, 33, 34, -1]])
I have a list of lists of indices, and I want to set the corresponding entries of x to a specific value (here -1). Each sublist of list_of_indices corresponds to a row of x and contains the column indices to set to -1 in that row. This can easily be done with a for-loop, but I feel like PyTorch should allow doing it much more efficiently.
I tried the following:
x[torch.arange(len(list_of_indices)), list_of_indices] = -1
but it resulted in
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [9, 0]
I tried to find people having the same problem, but the number of questions about indexing tensors is so large that I might have missed it.
To handle this in-place you can simply flatten x. Ravel the indices so that they can index the flattened x: convert each (row, column) pair to a flat offset row * m + column. Then gather all the offsets and assign through x.flatten() (which, for a contiguous tensor like this one, returns a view, not a copy).
indices = torch.tensor([i * m + j for i, r in enumerate(list_of_indices) for j in r])
> tensor([ 6, 7, 9, 24, 25, 26, 27, 32, 35])
x.flatten()[indices] = -1
> tensor([[ 0,  1,  2,  3],
          [ 4,  5, -1, -1],
          [ 8, -1, 10, 11],
          [12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23],
          [-1, -1, -1, -1],
          [28, 29, 30, 31],
          [-1, 33, 34, -1]])
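As a quick sanity check on the "not a copy" claim: for a contiguous tensor, flatten() returns a view sharing the same storage, so assigning through it writes into x itself. A minimal check:

```python
import torch

n, m = 9, 4
x = torch.arange(n * m).reshape(n, m)

# flatten() on a contiguous tensor returns a view over the same storage,
# so writing through it modifies x in place
flat = x.flatten()
print(flat.data_ptr() == x.data_ptr())  # True for a contiguous x
flat[0] = -1
print(x[0, 0].item())  # -1
```

Note that for a non-contiguous tensor, flatten() has to copy, and this trick would silently stop modifying x.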
You can also use the in-place Tensor.scatter_ method, although in that case it is slightly longer to write:
x.flatten().scatter_(0, indices, value=-1).view_as(x)
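Alternatively, here is a sketch that skips the flattening entirely: build matching row and column index tensors (each row index repeated once per entry in its sublist, via torch.repeat_interleave) and use standard advanced indexing. The names rows, cols, and lengths are just for illustration; the list comprehensions still run on the host, so this mainly buys you a single vectorized assignment instead of a per-row loop.

```python
import torch

n, m = 9, 4
x = torch.arange(n * m).reshape(n, m)
list_of_indices = [[], [2, 3], [1], [], [], [], [0, 1, 2, 3], [], [0, 3]]

# Repeat each row index by the length of its sublist, and flatten the
# column indices, so rows[k] pairs with cols[k] for every entry to set
lengths = torch.tensor([len(r) for r in list_of_indices])
rows = torch.repeat_interleave(torch.arange(n), lengths)
cols = torch.tensor([j for r in list_of_indices for j in r])

# A single advanced-indexing assignment replaces the per-row loop
x[rows, cols] = -1
```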