I have a tensor
import torch
a = torch.randn(1, 3, requires_grad=True)
print('a: ', a)
>>> a: tensor([[ 0.0200,  1.0020, -4.2000]], requires_grad=True)
And a mask
mask = torch.zeros_like(a)
mask[0][0] = 1
I want to mask my tensor a
without propagating gradients to the mask tensor (in my real case the mask requires a gradient). I tried the following
with torch.no_grad():
    b = a * mask
print('b: ', b)
>>> b: tensor([[0.0200, 0.0000, -0.0000]])
But this removes the gradient from my tensor entirely, since everything inside torch.no_grad() is excluded from the graph. What is the correct way to do it?
You can call detach() on the mask tensor. detach() returns a new tensor that shares the same data but is cut off from the autograd graph, so the multiplication still tracks gradients for a while nothing flows back into mask.
import torch

a = torch.randn(1, 3, requires_grad=True)
mask = torch.tensor([[1., 0., 0.]], requires_grad=True)

# detach() gives a view of mask that autograd ignores
mask_no_grad = mask.detach()

b = a * mask_no_grad
print(b)
> tensor([[0.3871, 0.0000, -0.0000]], grad_fn=<MulBackward0>)
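Note that b still has a grad_fn, so gradients flow back to a. A quick sanity check (a minimal sketch using only the tensors above) to confirm that backprop reaches a but not mask:

import torch

a = torch.randn(1, 3, requires_grad=True)
mask = torch.tensor([[1., 0., 0.]], requires_grad=True)

b = a * mask.detach()
b.sum().backward()

print(a.grad)     # tensor([[1., 0., 0.]]) -- gradient of sum(a * mask) w.r.t. a
print(mask.grad)  # None -- mask was detached, so it receives no gradient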