Clarifying PyTorch tensors: reference vs. value


Why does the first snippet, a = mat[0,0]; a = torch.tensor([99]), leave mat unchanged, while the second snippet, row = mat[0,:]; row[0] = torch.tensor([99]), modifies it?

>>> mat = torch.ones(2,3); print(mat)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
>>> a = mat[0,0]
>>> a = torch.tensor([99]); print(mat)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
>>> row = mat[0,:]
>>> row[0] = torch.tensor([99]); print(mat)
tensor([[99.,  1.,  1.],
        [ 1.,  1.,  1.]])

Solution

  • When you run a = torch.tensor([99]), you rebind the name a: it stops referring to the tensor returned by mat[0,0] and now refers to the new torch.tensor([99]). The assignment changes what the variable a refers to; it never writes into mat's memory, so mat is unchanged.

    When you run row[0] = torch.tensor([99]), the name row still refers to the same tensor, but the element row[0] is overwritten in place. Because row is a view that shares memory with mat, mat changes as well. The assignment here does not change the variable row; it changes a specific element of row.

    You can compare the two assignments more directly.

    import torch

    mat = torch.ones(2,3)
    row = mat[0,:]               # row is a view of mat's first row
    row[0] = torch.tensor([99])  # here we change element `0` of `row` in place
    print(mat)                   # mat is changed

    mat = torch.ones(2,3)
    row = mat[0,:]
    row = torch.tensor([99])     # here we rebind the variable `row` without changing any elements
    print(mat)                   # mat is unchanged
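
One way to check the distinction above for yourself is sketched below. It is a minimal illustration, not part of the original answer: the names same_obj, shares_memory, a, and row_copy are made up for the example. It uses data_ptr() to confirm that the view starts at the same memory address as mat, uses `is` to show that item assignment keeps the same object while plain assignment rebinds the name, and uses fill_() and clone() to write through a 0-dim view and to take an independent copy.

```python
import torch

mat = torch.ones(2, 3)
row = mat[0, :]      # basic indexing returns a view of mat
same_obj = row       # a second name for the same tensor object

# The view starts at the same memory address as mat.
shares_memory = row.data_ptr() == mat.data_ptr()
print(shares_memory)                        # True

# Item assignment writes in place: row is still the same object,
# and mat sees the new value.
row[0] = 99.0
print(row is same_obj, mat[0, 0].item())    # True 99.0

# Plain assignment rebinds the name: row is now a different object,
# and mat is not touched.
row = torch.tensor([7.0])
print(row is same_obj, mat[0, 0].item())    # False 99.0

# To write through a 0-dim view such as mat[1, 1], use an in-place op.
a = mat[1, 1]
a.fill_(5.0)
print(mat[1, 1].item())                     # 5.0

# clone() gives an independent copy, so writes to it do not reach mat.
row_copy = mat[0, :].clone()
row_copy[1] = -1.0
print(mat[0, 1].item())                     # 1.0
```

If you want the behaviour of the first snippet in the question but with mat actually changing, an in-place operation on the view (as with fill_ here) or direct indexing such as mat[0,0] = 99 would be the usual options.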