I'd like to set all elements of a vector that are below a certain threshold to zero, and I'd like to do it in a way that still propagates gradients through the non-zero elements.
For example, in Theano I could write:
B = theano.tensor.switch(A < .1, 0, A)
Is there a way to do this in PyTorch?
As of PyTorch 0.4+ you can do this easily with torch.where (see doc, merged PR). It is as easy as in Theano. See for yourself with an example:
import torch

x = torch.arange(0., 4., requires_grad=True)  # x     = [0 1 2 3]
zeros = torch.zeros_like(x)                   # zeros = [0 0 0 0]

y = x ** 2                                    # y     = [0 1 4 9]
z = torch.where(y < 5, zeros, y)              # z     = [0 0 0 9]

# dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5)
z.backward(torch.ones_like(z))
x.grad                                        # dz/dx = [0 0 0 6]
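
Applied to the original question, here is a minimal sketch of the Theano switch translated to torch.where (the 0.1 threshold comes from the question; the values of A are made up for illustration, and A is assumed to be a floating-point tensor with requires_grad=True):

import torch

A = torch.tensor([0.05, 0.2, 0.08, 0.9], requires_grad=True)  # example values
B = torch.where(A < 0.1, torch.zeros_like(A), A)              # B = [0.0, 0.2, 0.0, 0.9]

B.sum().backward()
A.grad  # [0, 1, 0, 1]: gradient flows only through the elements that were kept

Elements below the threshold are replaced by a constant zero, so they receive zero gradient, while the surviving elements pass their gradient through unchanged.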