Difference between tensor addition assignment and assignment in pytorch


I found that pytorch treats tensor assignment and addition assignment differently. Examples are shown below:

import torch

x = torch.tensor(3.0)
print(id(x))
x = x + 5
print(id(x))

The results are

1647247869184
1647248066816

If we run the following code

x = torch.tensor(3.0)
print(id(x))
x += 5
print(id(x))

the results are

1647175563712
1647175563712
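
Another way to make the difference visible is to keep a second reference to the same tensor; a minimal sketch (x and y are just illustrative names):

import torch

x = torch.tensor(3.0)
y = x          # second reference to the same tensor object
x += 5         # in-place: writes into the shared tensor
print(y)       # tensor(8.) -- y sees the update
x = x + 5      # plain addition: rebinds x to a new tensor
print(y)       # tensor(8.) -- y is unaffected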

From these examples, we can see that addition assignment does not change the variable's address, while plain addition does. This has consequences for training neural networks. For instance, the pytorch tutorial "What is torch.nn really?" contains the following piece of code:

from IPython.core.debugger import set_trace

lr = 0.5  # learning rate
epochs = 2  # how many epochs to train for
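# n, bs, model, loss_func, weights, bias, x_train and y_train
# are defined earlier in the tutorial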

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        #         set_trace()
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        with torch.no_grad():
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

We can see that inside the torch.no_grad() context, subtraction assignment is used. If we change the subtraction assignment to a normal assignment, as shown below, the code does not work.

        with torch.no_grad():
            weights = weights - weights.grad * lr
            bias = bias - bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()
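
To see why this variant fails, here is a minimal self-contained sketch (w is a stand-in for weights, and 0.5 for lr): the reassignment creates a brand-new tensor inside no_grad(), whose .grad is None, so the following .grad.zero_() call has nothing to operate on.

import torch

w = torch.randn(3, requires_grad=True)   # stand-in for weights
loss = (w ** 2).sum()
loss.backward()                          # populates w.grad

with torch.no_grad():
    w = w - 0.5 * w.grad                 # rebinds w to a NEW tensor

print(w.requires_grad)                   # False: created inside no_grad()
print(w.grad)                            # None: the new tensor has no grad
# w.grad.zero_() would now raise AttributeError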

Now, I know that += or -= should be used if we do not want the variable's address to change. However, in plain Python, there is no difference between += and = for an int: both rebind the name to a new object, and thus change the address. Examples are shown below:

x = 3
print(id(x))
x += 1
print(id(x))
x = x + 1
print(id(x))

The results are

140736084850528
140736084850560
140736084850592
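
Interestingly, this equivalence is specific to immutable types such as int. A mutable built-in like list behaves like the tensor: += calls __iadd__ and mutates the object in place, keeping the same id, while = rebinds the name. A quick check with a plain list:

xs = [1, 2, 3]
print(id(xs))
xs += [4]        # list.__iadd__: mutates in place
print(id(xs))    # same id
xs = xs + [5]    # creates a new list and rebinds the name
print(id(xs))    # different id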

My questions are

  • Why does the difference between += and = exist in pytorch? Is that design intentional?
  • What is the benefit of allowing this difference to exist?

Solution

  • += is an in-place operation, i.e. it modifies the content of the original tensor without making a copy (so the memory address stays the same).

    Other examples:

    • x *= 3
    • x[…] = …
    • x.add_(1)
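
    A quick sketch confirming that each of these keeps the same object and the same underlying storage (x is just an illustrative tensor; data_ptr() returns the address of the tensor's memory):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    before = (id(x), x.data_ptr())

    x *= 3        # in-place multiply
    x[...] = 0.0  # in-place indexed assignment
    x.add_(1)     # in-place method (trailing underscore)

    after = (id(x), x.data_ptr())
    print(before == after)  # True: same object, same storage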

    In weights = weights - weights.grad * lr, it doesn't work because you are creating a new tensor (at a different address) that merely reuses the name weights. Since that new tensor is created inside the no_grad() block, it has requires_grad=False and its .grad is None, so the subsequent weights.grad.zero_() raises an AttributeError.

    By the way, in pytorch's optimizers the update is implemented like this:

    weights.add_(weights.grad, alpha=-lr)
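
    A hedged sketch of how this in-place form would slot into the tutorial's update step (weights, bias and lr as defined there):

    with torch.no_grad():
        # same effect as weights -= weights.grad * lr, but via the
        # explicit in-place method used by pytorch's optimizers
        weights.add_(weights.grad, alpha=-lr)
        bias.add_(bias.grad, alpha=-lr)
        weights.grad.zero_()
        bias.grad.zero_()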