Tags: pytorch, automatic-differentiation

Is there any need to modify the backward function in PyTorch?


Recently I have been working on self-defined models with a self-defined backward function (since the forward pass is not implemented via PyTorch autograd). Say I have a model whose forward function outputs two tensors, A and B. The backward function defines the gradients of these outputs with respect to a parameter P (dA/dP and dB/dP), and the gradient of the whole loss with respect to P is composed from dA/dP and dB/dP.
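
Schematically, the setup looks something like the sketch below (MyModelFn, P, and the formulas inside are placeholders for illustration, not my actual model):

    import torch

    # Minimal placeholder: the forward pass is computed "by hand" (outside autograd),
    # so backward is written manually and returns dL/dP assembled from the incoming
    # gradients dL/dA and dL/dB.
    class MyModelFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, P):
            A = P ** 2            # placeholder for the real computation of A
            B = torch.sin(P)      # placeholder for the real computation of B
            ctx.save_for_backward(P)
            return A, B

        @staticmethod
        def backward(ctx, grad_A, grad_B):
            # grad_A = dL/dA and grad_B = dL/dB are supplied by autograd.
            (P,) = ctx.saved_tensors
            dA_dP = 2 * P          # hand-derived dA/dP (placeholder)
            dB_dP = torch.cos(P)   # hand-derived dB/dP (placeholder)
            return grad_A * dA_dP + grad_B * dB_dP   # dL/dP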

Now I want to define a new differentiable function f(A, B) = A * g(B) (multiplication, not convolution), so by the product and chain rules the gradient of f(A, B) with respect to P will be (dA/dP) * g(B) + A * g'(B) * (dB/dP).

So I figure that, since dA/dP and dB/dP are already defined, f(A, B) should be directly differentiable and there is no need to modify the backward function. However, I am not sure about this, so my question is whether the reasoning above is correct and whether I need to modify the backward function (or even the forward function).


Solution

  • As long as you use differentiable built-in operators, Autograd will track your operations, and a backward call will result in gradient computation.

    In your case, you have the following two operations:

    C = g(B)
    Z = A*C
    

    So we can summarize the computation graph with:

    dL/dB <------\    
      B   -----\  \ 
                \ dC/dB 
                 \  \ <--- dL/dC ----\
                  -> g(B) = C  ----\  \
                                    \ dZ/dC
                                     \  \ <--- dL/dZ ---
                                      -> A * C = Z
                                     /  /
                                    / dZ/dA
      A   -------------------------/  /
    dL/dA <--------------------------/
    

    The incoming gradients for A and B (what your custom backward receives) correspond to dL/dA and dL/dB, respectively. Meanwhile, the gradient that reaches the "g" model is dL/dC, i.e. the gradient with respect to the output of g, from which the gradients of g's parameters follow. This quantity is computed with the chain rule as dL/dC = dL/dZ * dZ/dC, as illustrated in the sketch at the end of this answer.

    You can read more in another question: Understanding backpropagation in PyTorch.
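
    To make this concrete, here is a minimal sketch of the whole pipeline, reusing the placeholder MyModelFn from the question and picking torch.exp as a stand-in for g purely for illustration. Because f(A, B) = A * g(B) is built from differentiable built-in operators, no change to the custom backward is needed:

    # (assumes `import torch` and the MyModelFn sketch from the question are in scope)
    P = torch.randn(3, requires_grad=True)
    A, B = MyModelFn.apply(P)   # custom forward/backward supplies dA/dP and dB/dP

    C = torch.exp(B)            # g(B): any differentiable built-in op (placeholder choice)
    Z = A * C                   # f(A, B) = A * g(B)

    loss = Z.sum()              # some scalar loss on top of Z (placeholder)
    loss.backward()             # autograd chains dL/dZ through dZ/dA, dZ/dC and dC/dB,
                                # then hands dL/dA and dL/dB to MyModelFn.backward
    print(P.grad)               # dL/dP = (dL/dA)*dA/dP + (dL/dB)*dB/dP

    If g itself has learnable parameters (e.g. it is an nn.Module), the same backward call also fills in their .grad attributes, since dL/dC is propagated through g by autograd.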