Tags: python, tensorflow, jax, automatic-differentiation

JAX automatic differentiation


I have the following three functions implemented in JAX.

def helper_1(params1):
    ...  # calculations producing z
    return z

def helper_2(z, params2):
    ...  # calculations producing y
    return y

def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y

I am interested in the partial derivatives of the outputs of main, i.e. z and y, with respect to both params1 and params2. Since params1 and params2 are low-dimensional while z and y are high-dimensional, I am using the jax.jacfwd function.

When calling

jax.jacfwd(main, argnums=(0, 1))(params1, params2)

JAX computes the derivatives of z with respect to params1 (and with respect to params2, which in this case is just a bunch of zeros). My question is: does JAX recompute dz/dparams1 for the derivatives of y with respect to params1 and params2, or does it somehow figure out that this has already been computed?

I don't know if this is relevant, but the helper_1 function calls functions from the TensorFlow library for JAX. Thanks!
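
For concreteness, here is a toy version of the setup with made-up shapes and calculations; only the structure matches my real code:

import jax
import jax.numpy as jnp

def helper_1(params1):
    # made-up calculation: 3 parameters -> 9 outputs
    return jnp.outer(params1, params1).ravel()

def helper_2(z, params2):
    # made-up calculation: combines z with 2 parameters -> 18 outputs
    return (jnp.sin(z)[:, None] * params2[None, :]).ravel()

def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y

params1 = jnp.ones(3)
params2 = jnp.ones(2)
jac_z, jac_y = jax.jacfwd(main, argnums=(0, 1))(params1, params2)
dz_dp1, dz_dp2 = jac_z  # derivatives of z w.r.t. params1 and params2
dy_dp1, dy_dp2 = jac_y  # derivatives of y w.r.t. params1 and params2
print(dz_dp1.shape, dy_dp1.shape)  # (9, 3) (18, 3)
print(bool((dz_dp2 == 0).all()))   # True: z does not depend on params2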


Solution

  • In general, in the situation you describe, JAX's forward-mode autodiff will reuse the derivative of z when computing the derivative of y. If you wish, you can confirm this by looking at the jaxpr of your differentiated function:

    print(jax.make_jaxpr(jax.jacfwd(main, (0, 1)))(params1, params2))
    

    Though if your function is more than moderately complicated, the output might be hard to understand.
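
    For instance, with a small stand-in for main (the bodies below are hypothetical, chosen only to keep the jaxpr short), you can see that the operations producing z appear only once and feed both outputs:

    import jax
    import jax.numpy as jnp

    def small_main(p1, p2):
        z = jnp.sin(p1) * 2.0  # stand-in for helper_1
        y = jnp.cos(z) + p2    # stand-in for helper_2
        return z, y

    print(jax.make_jaxpr(jax.jacfwd(small_main, (0, 1)))(1.0, 1.0))
    # In the printed jaxpr, the sin for z and its tangent appear a single
    # time; their results feed y's derivatives rather than being recomputed.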

    As a general note, though, JAX's autodiff implementation does tend to produce a small number of unnecessary or duplicated computations. As a simple example, consider this:

    import jax
    print(jax.make_jaxpr(jax.grad(jax.lax.sin))(1.0))
    # { lambda ; a:f32[]. let
    #     _:f32[] = sin a
    #     b:f32[] = cos a
    #     c:f32[] = mul 1.0 b
    #   in (c,) }
    

    Here the primal value sin(a) is computed even though it is never used in computing the final output.

    In practice this can be addressed by wrapping your computation in jax.jit, in which case the XLA compiler takes care of optimization, including dead code elimination and de-duplication where applicable:

    result = jax.jit(jax.jacfwd(main, (0, 1)))(params1, params2)
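
    If you want to see what the compiler actually keeps, JAX's ahead-of-time APIs let you print the optimized HLO (the exact output varies by JAX version and backend). For the sin example above, the unused primal computation is typically removed:

    import jax

    compiled = jax.jit(jax.grad(jax.lax.sin)).lower(1.0).compile()
    print(compiled.as_text())  # optimized HLO, after dead code elimination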