
JAX `custom_vjp` for functions with multiple outputs


The JAX documentation covers custom derivatives for functions with a single output. How can I implement custom derivatives for a function with multiple outputs, such as this one?

# want to define custom derivative of out_2 with respect to *args
def test_func(*args, **kwargs):
    ...
    return out_1, out_2

Solution

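  • You can define custom derivatives for functions with any number of inputs and outputs. In a custom_jvp rule, `primals` and `tangents` each hold one element per input, and the rule returns a tuple of primal outputs together with a tuple containing one tangent per output. The same idea carries over to custom_vjp: the backward rule receives one cotangent per output and returns one cotangent per input. For example: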

    import jax
    import jax.numpy as jnp
    
    @jax.custom_jvp
    def f(x, y):
      return x * y, x / y
    
    @f.defjvp
    def f_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
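      # Calling f inside its own JVP rule is fine; it returns both primal outputs.
      primals_out = f(x, y)
      # One tangent per output, in the same order as the primal outputs:
      # d(x * y) and d(x / y).
      tangents_out = (x_dot * y + y_dot * x,
                      x_dot / y - y_dot * x / y ** 2)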
      return primals_out, tangents_out
    
    x = jnp.float32(0.5)
    y = jnp.float32(2.0)
    
    jax.jacobian(f, argnums=(0, 1))(x, y)
    # ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
    #  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
    

    Computing the Jacobian of the same function with JAX's standard (non-custom) differentiation gives identical results, which confirms the custom rule:

    def f2(x, y):
      return x * y, x / y
    
    jax.jacobian(f2, argnums=(0, 1))(x, y)
    # ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
    #  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
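Since the question title asks about `custom_vjp`, here is a minimal sketch of the same two-output function written with `jax.custom_vjp` instead. The name `g` and the choice of `(x, y)` as residuals are illustrative assumptions, not part of the original answer. The key point is that the backward rule receives a tuple of cotangents with one entry per output (matching the structure of the primal output) and must return a tuple with one cotangent per positional input.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x, y):
    # Two outputs, same as f in the custom_jvp example.
    return x * y, x / y

def g_fwd(x, y):
    # Forward pass: return the primal outputs plus residuals for the backward pass.
    return g(x, y), (x, y)

def g_bwd(res, cts):
    # `cts` holds one cotangent per output, in the same order as g's outputs.
    x, y = res
    ct_prod, ct_quot = cts
    # Accumulate each output's contribution to the input cotangents.
    x_bar = ct_prod * y + ct_quot / y
    y_bar = ct_prod * x - ct_quot * x / y ** 2
    # Return one cotangent per positional input of g.
    return x_bar, y_bar

g.defvjp(g_fwd, g_bwd)

x = jnp.float32(0.5)
y = jnp.float32(2.0)

# Full Jacobian via reverse mode: ((dg0/dx, dg0/dy), (dg1/dx, dg1/dy)).
jac = jax.jacrev(g, argnums=(0, 1))(x, y)

# Gradient of the second output only (out_2 in the question).
grads_out2 = jax.grad(lambda a, b: g(a, b)[1], argnums=(0, 1))(x, y)
```

One caveat: a `custom_vjp` function supports only reverse-mode differentiation, so `jax.jacrev` and `jax.grad` work, but `jax.jvp` or `jax.jacfwd` will raise an error; `custom_jvp` supports both modes. If, as in the question, you only need a custom rule for `out_2`, one option is to wrap just the computation of `out_2` in its own `custom_jvp` or `custom_vjp` function and compute `out_1` normally.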