The JAX documentation covers custom derivatives for functions with a single output. How can I implement custom derivatives for a function with multiple outputs, such as this one?
# I want to define a custom derivative of out_2 with respect to *args
def test_func(*args, **kwargs):
    ...
    return out_1, out_2
You can define custom derivatives for functions with any number of inputs and outputs: just add the appropriate number of elements to the primals and tangents tuples in the custom_jvp rule. For example:
import jax
import jax.numpy as jnp
@jax.custom_jvp
def f(x, y):
    return x * y, x / y
@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primals_out = f(x, y)
    tangents_out = (x_dot * y + y_dot * x,
                    x_dot / y - y_dot * x / y ** 2)
    return primals_out, tangents_out
x = jnp.float32(0.5)
y = jnp.float32(2.0)
jax.jacobian(f, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
# (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
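Below is a small sketch of two other things the rule above supports: calling jax.jvp directly to get the tangents of both outputs in one pass, and differentiating only the second output (the out_2 case from the question) by selecting it inside a wrapper. The wrapper name second_output and the tangent direction (1, 0) are illustrative choices, not part of the original answer; the custom_jvp definition is repeated so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

# Same custom_jvp rule as above, repeated so this snippet is self-contained.
@jax.custom_jvp
def f(x, y):
    return x * y, x / y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primals_out = f(x, y)
    tangents_out = (x_dot * y + y_dot * x,
                    x_dot / y - y_dot * x / y ** 2)
    return primals_out, tangents_out

x = jnp.float32(0.5)
y = jnp.float32(2.0)

# Forward mode: push the tangent direction (1, 0), i.e. d/dx,
# through both outputs at once using the custom rule.
primals_out, tangents_out = jax.jvp(
    f, (x, y), (jnp.float32(1.0), jnp.float32(0.0)))
print(tangents_out)  # d(out_1)/dx = 2.0, d(out_2)/dx = 0.5

# Hypothetical wrapper: differentiate only the second output.
# The custom rule is still used; reverse mode works because the
# JVP rule is linear in the tangents.
def second_output(x, y):
    return f(x, y)[1]

grads = jax.grad(second_output, argnums=(0, 1))(x, y)
print(grads)  # d(out_2)/dx = 0.5, d(out_2)/dy = -0.125
```

Selecting an output in a wrapper keeps a single custom rule for the whole function, so both outputs share one set of derivative definitions.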
Comparing this with the Jacobian JAX computes for the same function without a custom rule shows that the results are identical:
def f2(x, y):
    return x * y, x / y
jax.jacobian(f2, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
# (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))
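The same idea should carry over to reverse mode. As a hedged sketch (the names g, g_fwd, g_bwd and g_ref are hypothetical), here is the same function written with jax.custom_vjp instead: the backward rule receives one cotangent per output, packed in the same tuple structure as the outputs, and returns one cotangent per input. The last lines compare the two nested-tuple Jacobians programmatically instead of by eye.

```python
import jax
import jax.numpy as jnp

# Hypothetical variant: the same two-output function with a custom VJP.
@jax.custom_vjp
def g(x, y):
    return x * y, x / y

def g_fwd(x, y):
    # Return the primal outputs plus residuals saved for the backward pass.
    return g(x, y), (x, y)

def g_bwd(res, cotangents):
    x, y = res
    # One cotangent per output, in the same tuple structure as g's output.
    ct_prod, ct_quot = cotangents
    # One cotangent per input; contributions from both outputs are summed.
    x_bar = ct_prod * y + ct_quot / y
    y_bar = ct_prod * x - ct_quot * x / y ** 2
    return x_bar, y_bar

g.defvjp(g_fwd, g_bwd)

# Reference implementation without any custom rule.
def g_ref(x, y):
    return x * y, x / y

x = jnp.float32(0.5)
y = jnp.float32(2.0)

jac_custom = jax.jacobian(g, argnums=(0, 1))(x, y)
jac_ref = jax.jacobian(g_ref, argnums=(0, 1))(x, y)

# Check every entry of the two nested-tuple Jacobians.
matches = jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.allclose, jac_custom, jac_ref))
print(matches)
```

jax.jacobian defaults to reverse mode, so here it exercises g_bwd rather than a JVP rule; use custom_vjp when the backward pass is the part you need to control.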