I have a function which returns a list of arrays, and I need to find its derivative with respect to a single parameter. For instance, let's say we have
def fun(x):
    ...
    return [a, b, c]
where a, b, and c are multi-dimensional arrays (for example, 2×2×2 real arrays). Now I want to obtain [da/dx, db/dx, dc/dx]. By da/dx I mean the derivative of each element of the 2×2×2 array a with respect to x, so da/dx, db/dx, and dc/dx are all 2×2×2 arrays.
This is my first time using JAX differentiation, and most of the examples I find online are about functions that have scalar outputs.
From my search, I understand that one way to do this is to compute the gradient of each scalar in all these arrays one at a time (probably using vmap to make it faster). Is there a faster way? I think jax.jacobian might do the trick, but I am having a hard time finding documentation that describes exactly what it does. Any help is very much appreciated.
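The element-by-element approach described above can be sketched as follows. This is a minimal illustration, assuming a made-up `fun` whose three outputs are 2×2×2 arrays; `elementwise_derivative` is a hypothetical helper, not a JAX API. It takes one `jax.grad` per flattened output element, batches those gradients with `jax.vmap`, and then compares the result with `jax.jacobian`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: three 2x2x2 arrays that depend on a scalar x.
def fun(x):
    base = jnp.arange(8.0).reshape(2, 2, 2)
    a = base * x
    b = jnp.sin(base * x)
    c = (base + x) ** 2
    return [a, b, c]

def elementwise_derivative(g, x):
    """Hypothetical helper: d g(x)[...] / dx for every element, one scalar gradient each."""
    out = g(x)
    # Gradient of the k-th flattened element of g(x) with respect to x.
    def grad_of_element(k):
        return jax.grad(lambda x: g(x).ravel()[k])(x)
    # vmap batches the per-element gradients instead of looping in Python.
    return jax.vmap(grad_of_element)(jnp.arange(out.size)).reshape(out.shape)

x = 0.5
per_element = [elementwise_derivative(lambda x, i=i: fun(x)[i], x) for i in range(3)]
via_jacobian = jax.jacobian(fun)(x)
for p, j in zip(per_element, via_jacobian):
    print(bool(jnp.allclose(p, j)))  # True for each of the three outputs
```

Both routes compute the same derivatives; `jax.jacobian` simply does the batching internally, so there is usually no reason to write the per-element version by hand.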
Since then, I have tried jax.jacobian on simple examples, and it gives the answers I expect. That reassures me somewhat, but I would like to find official documentation, or confirmation from others, that this is the right way to do it and that it does what I expect.
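Independently of the documentation, one way to check jax.jacobian is to compare it with a central finite-difference estimate. A minimal sketch, assuming a made-up nonlinear `fun` and a hypothetical helper `central_difference`; the tolerance is loose because JAX defaults to float32.

```python
import jax
import jax.numpy as jnp

# Hypothetical test function with nonlinear dependence on the scalar x.
def fun(x):
    base = jnp.arange(8.0).reshape(2, 2, 2)
    return [jnp.sin(base * x), jnp.exp(x) * base, x ** 3 * jnp.ones((2, 2, 2))]

def central_difference(g, x, eps=1e-3):
    """Approximate d g(x)/dx for each output array as (g(x+eps) - g(x-eps)) / (2*eps)."""
    plus, minus = g(x + eps), g(x - eps)
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]

x = 0.7
analytic = jax.jacobian(fun)(x)
numeric = central_difference(fun, x)
for a, n in zip(analytic, numeric):
    # Loose absolute tolerance: the finite difference is computed in float32.
    print(bool(jnp.allclose(a, n, atol=1e-2)))
```

A check like this only builds confidence on the inputs you try; it does not replace understanding what jax.jacobian computes.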
You can use jax.jacobian for what you describe. Here is an example:
import jax
import jax.numpy as jnp
def f(x):
    a = jnp.full((2, 2), 2) * x
    b = jnp.full((2, 2), 3) * x
    c = jnp.full((2, 2), 4) * x
    return [a, b, c]
da_dx, db_dx, dc_dx = jax.jacobian(f)(1.0)
print(da_dx)
# [[2. 2.]
#  [2. 2.]]
print(db_dx)
# [[3. 3.]
#  [3. 3.]]
print(dc_dx)
# [[4. 4.]
#  [4. 4.]]
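The same call also works when the input is an array rather than a scalar: each block of the result then has shape `output.shape + input.shape`. A small sketch with a made-up function `g` of a length-3 vector:

```python
import jax
import jax.numpy as jnp

# Hypothetical function of a vector x (shape (3,)) returning two outputs.
def g(x):
    a = jnp.outer(jnp.ones(2), x)  # shape (2, 3), a[i, j] = x[j]
    b = jnp.sum(x ** 2)            # scalar
    return [a, b]

x = jnp.array([1.0, 2.0, 3.0])
da_dx, db_dx = jax.jacobian(g)(x)
print(da_dx.shape)  # (2, 3, 3): output shape (2, 3) followed by input shape (3,)
print(db_dx)        # [2. 4. 6.]: the gradient of sum(x**2) is 2*x
```

For a scalar x, as in the question, the input shape is `()`, so each block has exactly the shape of the corresponding output array.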
jax.jacobian is an alias of jax.jacrev, and you can find the documentation here: https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html
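Because jax.jacrev is reverse mode, it effectively runs one backward pass per output element (batched with vmap). When the input is a single scalar and the outputs are larger, forward mode is usually the better fit: jax.jacfwd, or a single jax.jvp with tangent 1.0, yields every output derivative in one forward pass. A sketch reusing the `f` from the example above; which variant is faster for a real function is worth timing rather than assuming.

```python
import jax
import jax.numpy as jnp

def f(x):
    a = jnp.full((2, 2), 2) * x
    b = jnp.full((2, 2), 3) * x
    c = jnp.full((2, 2), 4) * x
    return [a, b, c]

x = jnp.float32(1.0)

# Forward-mode Jacobian: same list-of-arrays structure as jax.jacobian (jax.jacrev).
fwd = jax.jacfwd(f)(x)

# For a scalar input, one Jacobian-vector product with tangent 1.0 gives the
# derivative of every output element; jvp also returns the primal outputs f(x).
primals, tangents = jax.jvp(f, (x,), (jnp.float32(1.0),))

rev = jax.jacrev(f)(x)
for t, fw, r in zip(tangents, fwd, rev):
    print(bool(jnp.allclose(t, fw)) and bool(jnp.allclose(fw, r)))  # True
```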