deep-learning, theano, pytorch

Is there an analog of theano.tensor.switch in PyTorch?


I'd like to force all elements of a vector that are below a certain threshold to zero, while still being able to propagate gradients through the non-zero ones.

For example, in theano I could write:

B = theano.tensor.switch(A < .1, 0, A)

Is there a solution for that in PyTorch?


Solution

  • As of PyTorch 0.4+, you can do it easily with torch.where (see the documentation and the merged PR).

    It is as easy as in Theano. See for yourself with this example:

    import torch
    
    # Since 0.4, Tensor and Variable are merged, so no Variable wrapper is needed;
    # arange must produce a float tensor for requires_grad to work.
    x = torch.arange(0, 4, dtype=torch.float32, requires_grad=True)  # x     = [0 1 2 3]
    zeros = torch.zeros_like(x)                                      # zeros = [0 0 0 0]
    
    y = x**2                         # y = [0 1 4 9]
    z = torch.where(y < 5, zeros, y) # z = [0 0 0 9]
    
    # dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5)
    z.backward(torch.ones_like(z))   # backward on a non-scalar needs an explicit gradient
    x.grad                           # (dz/dx) = [0 0 0 6]
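
    Applied directly to the Theano snippet from the question, a minimal sketch looks like this (A and the .1 threshold come from the question; the concrete values of A are made up for illustration):

    A = torch.tensor([0.05, 0.2, 0.08, 0.9], requires_grad=True)  # made-up example values
    B = torch.where(A < .1, torch.zeros_like(A), A)               # B = [0.0, 0.2, 0.0, 0.9]
    # Gradients still flow through the elements that were kept (0.2 and 0.9),
    # while the zeroed-out elements receive zero gradient.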