
Pass gradients through min function in PyTorch


I have a variable, say, a, which has a gradient associated with it from earlier operations. I also have integers b and c, which have no gradients. I want to compute the minimum of a, b, and c. A MWE is given below.

import torch

a = torch.tensor([4.], requires_grad=True)  # As an example I define a leaf tensor here; in my program it is an actual variable with a gradient history
b = 5
c = 6
d = torch.min(torch.tensor([a, b, c]))  # d has no gradient associated with it

How can I write this differently so that the gradient from a flows through to d? Thanks.


Solution

  • The problem with your code is the line d = torch.min(torch.tensor([a, b, c]))

    When you compute torch.tensor([a, b, c]), you create a new tensor that has no connection in the computational graph to the tensors a, b, or c: the values are copied out and the autograd history is discarded. For example:

    a = torch.tensor([4.], requires_grad=True)
    b = torch.tensor([5.])
    c = torch.tensor([6.])
    d = torch.tensor([a,b,c])
    d.requires_grad
    > False
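
    If you do need to gather the values into one tensor first, torch.stack keeps its inputs connected to the graph (unlike the torch.tensor constructor, which copies the values out). A minimal sketch, assuming b and c have already been wrapped as float tensors as above:

    a = torch.tensor([4.], requires_grad=True)
    b = torch.tensor([5.])
    c = torch.tensor([6.])
    d = torch.min(torch.stack([a, b, c]))  # stack preserves the autograd history
    d.requires_grad
    > True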
    

    The solution is to call the min function on the input tensors themselves, so that the operation is recorded in the computational graph.

    a = torch.tensor([4.], requires_grad=True)
    b = torch.tensor([5.])
    c = torch.tensor([6.])
    d = a.min(b).min(c)
    d.requires_grad
    > True
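
    Applied to the original question, where b and c are plain Python ints, one way (a sketch; the exact conversion is up to you) is to wrap them as tensors before taking the minimum. The wrapping matters because passing a plain int to Tensor.min would be interpreted as a dim argument rather than a value to compare against:

    a = torch.tensor([4.], requires_grad=True)
    b, c = 5, 6  # plain Python ints, as in the question
    d = a.min(torch.tensor(float(b))).min(torch.tensor(float(c)))
    d.backward()
    a.grad
    > tensor([1.])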
    

    Note that the gradient of the min function is 1 for the min value and 0 for all other values. This means that you will lose gradient signal if the value you want to backprop through is not the min.

    a = torch.tensor([4.], requires_grad=True)
    b = torch.tensor([5.])
    d = a.min(b)
    d.backward()
    a.grad
    > tensor([1.])
    
    a = torch.tensor([6.], requires_grad=True)
    b = torch.tensor([5.])
    d = a.min(b)
    d.backward()
    a.grad
    > tensor([0.])
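
    On recent PyTorch versions (1.7 and later), torch.minimum is an equivalent element-wise alternative with the same all-or-nothing gradient behaviour; a quick sketch:

    a = torch.tensor([4.], requires_grad=True)
    b = torch.tensor([5.])
    c = torch.tensor([6.])
    d = torch.minimum(torch.minimum(a, b), c)  # gradient flows only to the min value
    d.backward()
    a.grad
    > tensor([1.])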