Recently I have been working on custom models with a custom backward function (since the forward pass is not implemented via PyTorch autograd). Say I have a model whose forward function outputs two tensors, A and B. The backward function defines the gradients of both with respect to a parameter P (dA/dP and dB/dP), and the gradient of the overall loss is composed from dA/dP and dB/dP.
Now I want to define a new differentiable function, f(A, B) = A * g(B) (multiplication, not convolution), so its gradient with respect to P will be df/dP = (dA/dP) * g(B) + A * g'(B) * (dB/dP).
So I figure that, since dA/dP and dB/dP are already defined, f(A, B) should be directly differentiable, and there would be no need to modify the backward function. Since I am not sure about this, my question is whether the above reasoning is correct, and whether I need to modify the backward function (or even the forward function).
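For concreteness, here is a minimal sketch of the kind of setup I mean (MyOp is a placeholder name, and the formulas inside forward are stand-ins for my real, non-autograd forward pass):

import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, P):
        ctx.save_for_backward(P)
        # Stand-ins: the real forward is computed outside of autograd.
        A = P ** 2
        B = 3.0 * P
        return A, B

    @staticmethod
    def backward(ctx, grad_A, grad_B):
        (P,) = ctx.saved_tensors
        # Hand-written dA/dP = 2P and dB/dP = 3, combined with the
        # incoming gradients grad_A = dL/dA and grad_B = dL/dB.
        return grad_A * 2.0 * P + grad_B * 3.0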
As long as you use differentiable built-in operators, Autograd will track your operations, and a backward call will result in gradient computation.
In your case, you have the following two operations:
C = g(B)
Z = A*C
So we can summarize the computation graph with:
dL/dB <------\
B -----\      \
        \    dC/dB
         \      \  <--- dL/dC ----\
          -> g(B) = C ----\        \
                           \      dZ/dC
                            \        \  <--- dL/dZ ---
                             -> A * C = Z
                            /        /
                           /      dZ/dA
A -------------------------/        /
dL/dA <--------------------------/
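To make the diagram concrete, here is a small sketch using only built-in operators (the concrete values and the choice g = torch.sin are assumptions for illustration, not from the question); retain_grad() keeps the gradient of the non-leaf tensor C so we can inspect dL/dC:

import torch

A = torch.tensor(2.0, requires_grad=True)
B = torch.tensor(0.5, requires_grad=True)

C = torch.sin(B)   # C = g(B), with g = sin as a stand-in
C.retain_grad()    # keep dL/dC on this non-leaf tensor
Z = A * C
Z.backward()       # take L = Z, so dL/dZ = 1

print(A.grad)      # dL/dA = dL/dZ * dZ/dA = C = sin(0.5)
print(B.grad)      # dL/dB = dL/dC * dC/dB = A * cos(B)
print(C.grad)      # dL/dC = dL/dZ * dZ/dC = A = 2.0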
Your input gradients A.grad and B.grad correspond to dL/dA and dL/dB, respectively. Meanwhile, your model gradient will be given by dL/dC, which is in effect dL/dg (by g, we refer to the parameters of the "g" model). This quantity is computed with the chain rule as dL/dC = dL/dZ * dZ/dC.
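Putting it together, here is a hedged end-to-end sketch (MyOp and its stand-in formulas mirror the sketch in the question; g = torch.sin is an illustrative assumption): the custom backward only supplies dA/dP and dB/dP, and autograd chains them through the built-in operations downstream, so f(A, B) = A * g(B) requires no change to the backward function:

import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, P):
        ctx.save_for_backward(P)
        A = P ** 2        # stand-in custom forward
        B = 3.0 * P
        return A, B

    @staticmethod
    def backward(ctx, grad_A, grad_B):
        (P,) = ctx.saved_tensors
        return grad_A * 2.0 * P + grad_B * 3.0   # dA/dP = 2P, dB/dP = 3

P = torch.tensor(2.0, requires_grad=True)
A, B = MyOp.apply(P)
f = A * torch.sin(B)     # f(A, B) = A * g(B), tracked by autograd
f.backward()

# Analytic check: df/dP = (dA/dP)*g(B) + A*g'(B)*(dB/dP)
# With P = 2: A = 4, B = 6, dA/dP = 4, dB/dP = 3.
expected = (4.0 * torch.sin(torch.tensor(6.0))
            + 4.0 * torch.cos(torch.tensor(6.0)) * 3.0)
print(P.grad, expected)  # both are approximately 10.40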
You can read more in another question: Understanding backpropagation in PyTorch.