Tags: python, pytorch

PyTorch: index with a tensor


I have a two-dimensional tensor arr whose entries are all 0, and a second tensor idx. I want to set to 1 every entry of arr at the indices given in idx, where row i of idx holds the column indices to set in row i of arr.

import torch

arr = torch.zeros(size=(2, 10))
idx = torch.Tensor([
    [0, 2],
    [4, 5]
])
arr[idx] = 1  # This doesn't work
print(arr)

The output should look like this:

tensor([[1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 0.]])
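To pin down the intended semantics, here is a plain-loop reference sketch. It is slow for large tensors but unambiguous: for each row i, the columns listed in idx[i] are set to 1. The helper name set_ones_loop is hypothetical. It also shows one likely culprit in the code above: torch.Tensor(...) always builds a float tensor, while torch.tensor(...) infers an integer dtype (int64) from Python ints, and only integer or boolean tensors can be used as indices.

```python
import torch

def set_ones_loop(arr: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: return a copy of arr in which,
    for each row i, the columns listed in idx[i] are set to 1."""
    out = arr.clone()  # leave the input untouched
    for i in range(idx.shape[0]):
        for j in idx[i].tolist():  # column indices for row i
            out[i, j] = 1
    return out

arr = torch.zeros(size=(2, 10))
idx = torch.tensor([[0, 2], [4, 5]])  # lowercase torch.tensor -> int64
print(idx.dtype)                      # torch.int64
print(torch.Tensor([[0, 2]]).dtype)   # torch.float32 (default float dtype)
print(set_ones_loop(arr, idx))        # matches the expected output above
```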

I was sure someone else must have asked this on SO, but I couldn't find it. I hope this isn't a duplicate.


Solution

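  • Use scatter() along dim=1, which in this case is also the innermost dimension, i.e. dim=-1. Note that instead of a src tensor, I just passed the scalar value 1. Also create idx with torch.tensor (lowercase) rather than torch.Tensor, so that it gets an integer (int64) dtype; scatter() requires an int64 index tensor.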

    In [31]: arr = torch.zeros(size=(2, 10))
     
    In [32]: idx = torch.tensor([
         ...:     [0, 2],
         ...:     [4, 5]
         ...:     ])
     
    In [33]: torch.scatter(arr, 1, idx, 1)
    Out[33]: 
    tensor([[1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1., 1., 0., 0., 0., 0.]])
    
    In [34]: torch.scatter(arr, -1, idx, 1)
    Out[34]: 
    tensor([[1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1., 1., 0., 0., 0., 0.]])
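torch.scatter returns a new tensor and leaves arr unchanged. Here is a minimal sketch of two alternatives, assuming the same arr and idx as in the session: the in-place method Tensor.scatter_ (note the trailing underscore), and plain advanced indexing, where a column vector of row numbers is broadcast against idx so that each (row, column) pair is addressed. Variable names such as arr_inplace, arr_adv and rows are only illustrative.

```python
import torch

arr = torch.zeros(size=(2, 10))
idx = torch.tensor([[0, 2],
                    [4, 5]])  # index tensors must be int64 for scatter

# Out-of-place: returns a new tensor, arr itself stays all zeros
out = torch.scatter(arr, 1, idx, 1)

# In-place variant: writes the 1s directly into arr_inplace
arr_inplace = torch.zeros(size=(2, 10))
arr_inplace.scatter_(1, idx, 1)

# Advanced indexing: rows has shape (2, 1) and broadcasts against idx (2, 2),
# so arr_adv[rows, idx] addresses (0, 0), (0, 2), (1, 4), (1, 5)
arr_adv = torch.zeros(size=(2, 10))
rows = torch.arange(idx.shape[0]).unsqueeze(1)
arr_adv[rows, idx] = 1

print(torch.equal(out, arr_inplace), torch.equal(out, arr_adv))  # True True
```

scatter generalizes directly to other dimensions via its dim argument, while the advanced-indexing form is arguably easier to read for this 2-D case.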