
How to mask a tensor without losing the gradient?


I have a tensor

import torch
a = torch.randn(1, 3, requires_grad=True)
print('a: ', a)
>>> a:  tensor([[0.0200, 1.0020, -4.2000]], requires_grad=True)

And a mask

mask = torch.zeros_like(a)
mask[0][0] = 1

I want to mask my tensor a without propagating gradients to the mask tensor (in my real case the mask has a gradient of its own). I tried the following:

with torch.no_grad():
    b = a * mask
    print('b: ', b)
    >>> b:  tensor([[0.0200, 0.0000, -0.0000]])

But this detaches b from the graph entirely, so no gradient can flow back to a either. What is the correct way to do it?


Solution

  • You can call detach() on the mask to get a copy that is cut off from the computation graph. Multiplying by the detached mask keeps the gradient path to a intact while recording nothing for mask:

    import torch

    a = torch.randn(1, 3, requires_grad=True)

    # the mask has requires_grad=True here to mirror the real use case
    mask = torch.tensor([[1., 0., 0.]], requires_grad=True)

    # detach() returns a tensor that shares the same data but is
    # excluded from gradient tracking
    mask_no_grad = mask.detach()

    b = a * mask_no_grad
    print(b)
    >>> tensor([[0.3871, 0.0000, -0.0000]], grad_fn=<MulBackward0>)
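
    As a quick sanity check (a minimal sketch; the sum() reduction is only there to produce a scalar so backward() can be called), you can confirm that gradients reach a but not mask:

    import torch

    a = torch.randn(1, 3, requires_grad=True)
    mask = torch.tensor([[1., 0., 0.]], requires_grad=True)

    b = a * mask.detach()
    b.sum().backward()  # reduce to a scalar before calling backward()

    print(a.grad)     # tensor([[1., 0., 0.]]) -- gradient of sum(a * mask) w.r.t. a
    print(mask.grad)  # None -- nothing flowed back through the detached copy

    a.grad matches the mask values, while mask.grad stays None because the detached copy recorded no gradient path.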