I have the following three functions implemented in JAX:
```python
def helper_1(params1):
    # ...calculations...
    return z

def helper_2(z, params2):
    # ...calculations...
    return y

def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y
```
I am interested in the partial derivatives of the outputs of `main`, i.e. `z` and `y`, with respect to both `params1` and `params2`. As `params1` and `params2` are low dimensional and `z` and `y` are high dimensional, I am using the `jax.jacfwd` function.
When calling

```python
jax.jacfwd(main, argnums=(0, 1))(params1, params2)
```

JAX computes the derivatives of `z` with respect to `params1` (and with respect to `params2`, which in this case are just a bunch of zeros). My question is: does JAX recompute dz/dparams1 when computing the derivatives of `y` with respect to `params1` and `params2`, or does it somehow figure out that this has already been computed?
I don't know if this is relevant, but the 'helper_1' function contains functions from the TensorFlow library for Jax. Thanks!
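For concreteness, here is a minimal runnable version of the setup above. The bodies of `helper_1` and `helper_2`, the grid `X`, and the parameter values are hypothetical stand-ins (the real calculations are not shown), chosen only so that `params1` and `params2` are small while `z` and `y` are larger. It also shows how the result of `jax.jacfwd(main, argnums=(0, 1))` is nested: the outer level follows the outputs `(z, y)` and the inner level follows the arguments `(params1, params2)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the elided calculations in the question.
X = jnp.linspace(0.0, 1.0, 100)  # hypothetical fixed grid of 100 points


def helper_1(params1):
    # z depends only on params1; shape (100,)
    return params1[0] * jnp.sin(params1[1] * X)


def helper_2(z, params2):
    # y depends on z and params2; shape (100,)
    return params2[0] * z**2 + params2[1] * z + params2[2]


def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y


params1 = jnp.array([1.0, 2.0])        # low dimensional (2,)
params2 = jnp.array([0.5, -1.0, 3.0])  # low dimensional (3,)

# Outer tuple: outputs (z, y); inner tuples: arguments (params1, params2).
(dz_dp1, dz_dp2), (dy_dp1, dy_dp2) = jax.jacfwd(main, argnums=(0, 1))(params1, params2)

print(dz_dp1.shape, dz_dp2.shape)  # (100, 2) (100, 3)
print(dy_dp1.shape, dy_dp2.shape)  # (100, 2) (100, 3)
```

Because `z` does not depend on `params2`, the block `dz_dp2` is all zeros, matching the observation in the question.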
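In general, in the situation you describe, JAX's forward-mode autodiff will reuse the derivative of `z` when computing the derivative of `y`: `jacfwd` is built on forward-mode JVPs, and in the JVP of `main` the tangent of `z` produced by `helper_1` is passed directly into `helper_2` rather than being recomputed. If you wish, you can confirm this by looking at the jaxpr of your differentiated function: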
```python
print(jax.make_jaxpr(jax.jacfwd(main, (0, 1)))(params1, params2))
```
Though if your function is more than moderately complicated, the output might be hard to understand.
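If the jaxpr is too large to read, it can help to write out by hand what forward mode does. The sketch below computes the Jacobians with respect to `params1` by pushing each standard basis vector through `jax.jvp`, batched with `jax.vmap` (the same vmap-of-JVP idea that `jacfwd` is built on), and checks the result against `jax.jacfwd`. The helper bodies, `X`, and the parameter values are hypothetical stand-ins for the elided code; the point is that `z_dot` is computed once by `helper_1`'s JVP and then handed to `helper_2`'s JVP.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's helpers (real bodies are elided).
X = jnp.linspace(0.0, 1.0, 100)


def helper_1(params1):
    return params1[0] * jnp.sin(params1[1] * X)


def helper_2(z, params2):
    return params2[0] * z**2 + params2[1] * z + params2[2]


def main(params1, params2):
    z = helper_1(params1)
    return z, helper_2(z, params2)


params1 = jnp.array([1.0, 2.0])
params2 = jnp.array([0.5, -1.0, 3.0])


def push_forward(t1):
    # One forward-mode pass along a single tangent direction t1 of params1.
    # helper_1 is differentiated once, giving z and its tangent z_dot ...
    z, z_dot = jax.jvp(helper_1, (params1,), (t1,))
    # ... and z_dot is fed straight into helper_2's JVP (params2 held fixed),
    # so dz/dparams1 is not recomputed when differentiating y.
    _, y_dot = jax.jvp(helper_2, (z, params2), (z_dot, jnp.zeros_like(params2)))
    return z_dot, y_dot


# Push all basis vectors of params1 through at once; out_axes=-1 puts the
# input dimension last, matching the (output, input) layout of jacfwd.
basis = jnp.eye(params1.size)
dz_dp1, dy_dp1 = jax.vmap(push_forward, out_axes=-1)(basis)

(ref_dz_dp1, _), (ref_dy_dp1, _) = jax.jacfwd(main, argnums=(0, 1))(params1, params2)
print(jnp.allclose(dz_dp1, ref_dz_dp1), jnp.allclose(dy_dp1, ref_dy_dp1))  # True True
```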
As a general note, though, JAX's autodiff implementation does tend to produce a small number of unnecessary or duplicated computations. As a simple example, consider this:
```python
import jax
print(jax.make_jaxpr(jax.grad(jax.lax.sin))(1.0))
# { lambda ; a:f32[]. let
#     _:f32[] = sin a
#     b:f32[] = cos a
#     c:f32[] = mul 1.0 b
#   in (c,) }
```
Here the primal value `sin(a)` is computed even though it is never used in computing the final output.
In practice this can be addressed by wrapping your computation in `jit`, in which case the XLA compiler takes care of optimization, including dead-code elimination and de-duplication where applicable:

```python
result = jax.jit(jax.jacfwd(main, (0, 1)))(params1, params2)
```
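A minimal sketch of the jitted version, again with hypothetical helper bodies, `X`, and parameter values standing in for the elided calculations. It compiles the whole `jacfwd` computation once, checks that the values agree with the un-jitted call up to floating-point rounding, and retrieves the text of the optimized program XLA will actually run, which is where dead-code elimination shows up.

```python
import jax
import jax.numpy as jnp

X = jnp.linspace(0.0, 1.0, 100)  # hypothetical fixed grid


def helper_1(params1):  # hypothetical stand-in for the elided code
    return params1[0] * jnp.sin(params1[1] * X)


def helper_2(z, params2):  # hypothetical stand-in for the elided code
    return params2[0] * z**2 + params2[1] * z + params2[2]


def main(params1, params2):
    z = helper_1(params1)
    return z, helper_2(z, params2)


params1 = jnp.array([1.0, 2.0])
params2 = jnp.array([0.5, -1.0, 3.0])

# Build and jit the Jacobian function once; later calls with inputs of the
# same shape and dtype reuse the compiled executable.
jac_fn = jax.jit(jax.jacfwd(main, argnums=(0, 1)))
jitted = jac_fn(params1, params2)
eager = jax.jacfwd(main, argnums=(0, 1))(params1, params2)

# Compare leaf by leaf; results should agree up to floating-point rounding.
agree = all(jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.allclose, jitted, eager)))
print(agree)

# Text of the optimized HLO after XLA's passes (may be None on some backends).
optimized_hlo = jac_fn.lower(params1, params2).compile().as_text()
```

The first call to `jac_fn` includes compilation time, so any timing comparison should be made on subsequent calls.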