
Inplace add with list selectors


I'm seeing somewhat inconsistent behaviour in PyTorch depending on whether the index is a list or an integer. Take a look at this snippet:

import torch

# First example, integer selector ==> Ok
t = torch.tensor([[0, 1], [1, 0]])
t[0, 0].add_(10)
print(t)
tensor([[10,  1],
        [ 1,  0]])

# Second example, list selector ==> ???
t = torch.tensor([[0, 1], [1, 0]])
t[[0], [0]].add_(10) # notice the list selector
print(t)
tensor([[0, 1],
        [1, 0]])

# Third example, list selector with in-place add operator ==> Ok
t = torch.tensor([[0, 1], [1, 0]])
t[[0], [0]] += 10
print(t)
tensor([[10,  1],
        [ 1,  0]])

I can't understand why PyTorch did not update t in the second example!


Solution

  • Compare the shapes returned by the two kinds of indexing:

    In []: t[0, 0].shape
    
    Out[]: torch.Size([])
    
    In []: t[[0], [0]].shape
    
    Out[]: torch.Size([1])
    

    When you index the (0, 0) element of t directly with integers, you get a view that shares memory with t, so an in-place add_ on it modifies t. The shape of t[0, 0] is [], i.e. a 0-dim (scalar) tensor referring to the (0, 0) entry.
    However, when you use list indices ([0], [0]) you get back a 1-dim tensor of shape [1]. This is advanced indexing, which returns a copy of the selected elements of t. The in-place add_ is then applied to that copy and has no effect on the original t:

    In []: r = t[[0], [0]].add_(10)
    In []: t
    Out[]:
    tensor([[0, 1],
            [1, 0]])
    
    In []: r
    Out[]: tensor([10])
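
The view-versus-copy distinction can be checked directly. The following is a minimal sketch, assuming a recent PyTorch; the names viewed and picked are made up for illustration. It compares memory addresses with data_ptr() and then applies an in-place add to each result.

```python
import torch

t = torch.tensor([[0, 1], [1, 0]])

viewed = t[0, 0]      # integer (basic) indexing: 0-dim view into t
picked = t[[0], [0]]  # list (advanced) indexing: new 1-dim tensor

# t[0, 0] starts at the same address as t's first element; the copy does not.
shares_view = viewed.data_ptr() == t.data_ptr()
shares_copy = picked.data_ptr() == t.data_ptr()
print(shares_view, shares_copy)

viewed.add_(10)  # writes through to t
picked.add_(5)   # only changes the separate copy
print(t)
print(picked)
```

Only the add_ on viewed should show up in t; the add_ on picked stays in its own tensor.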
    

    Perhaps you want to look into index_add_() to accomplish your task.
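
As a rough sketch of that suggestion: index_add_ operates along a single dimension, so adding 10 to the (0, 0) entry means adding the row [10, 0] to row 0. The second half uses index_put_ with accumulate=True, which is an alternative not mentioned above and is shown here only as an assumption about what fits this use case; it takes one index tensor per dimension, much like t[[0], [0]].

```python
import torch

t = torch.tensor([[0, 1], [1, 0]])

# index_add_(dim, index, source): add the row [10, 0] to row 0 of t, in place.
rows = torch.tensor([0])
t.index_add_(0, rows, torch.tensor([[10, 0]]))
print(t)

# index_put_(indices, values, accumulate): one index tensor per dimension.
u = torch.tensor([[0, 1], [1, 0]])
indices = (torch.tensor([0]), torch.tensor([0]))
u.index_put_(indices, torch.tensor([10]), accumulate=True)
print(u)
```

Both calls modify the tensor they are called on, so no intermediate copy is lost.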

    Update: When you assign to t using list indices, no copy is created (copying would make no sense for an assignment). So,

    t[[0], [0]] += 10
    

    is effectively equivalent to

    t[[0], [0]] = t[[0], [0]] + 10
    

    That is, the right-hand side is a copy of the (0, 0) sub-tensor of t with 10 added to it, which gives a tensor of shape [1] with value [10]. The left-hand side then assigns this [10] into the (0, 0) entry of t itself, not into a copy.
    Therefore, after t[[0], [0]] += 10, t is:

    tensor([[10,  1],
            [ 1,  0]])
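
To make the read-then-write sequence concrete, here is a small sketch that spells out the steps by hand. The names idx, tmp and before are illustrative; the tuple ([0], [0]) is what Python passes to the tensor when you write t[[0], [0]].

```python
import torch

t = torch.tensor([[0, 1], [1, 0]])
idx = ([0], [0])       # same index as t[[0], [0]]

tmp = t[idx]           # step 1: read, advanced indexing returns a copy
tmp += 10              # step 2: add 10 to the copy only
before = t.tolist()    # t is still unchanged at this point

t[idx] = tmp           # step 3: write the result back into t itself
print(before)
print(t)
```

The write-back in step 3 is what the bare add_ call on t[[0], [0]] is missing, which is why only the += form changes t.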