Computing gradient using JAX of a function that outputs a list of arrays


I have a function which returns a list of arrays, and I need to find its derivative with respect to a single parameter. For instance, let's say we have

def fun(x):
    ...
    return [a, b, c]

where a, b, and c are multi-dimensional arrays (for example, 2 by 2 by 2 real arrays). Now I want to obtain [da/dx, db/dx, dc/dx]. By da/dx I mean the derivative of each element of the 2×2×2 array a with respect to x, so da/dx, db/dx, and dc/dx are all 2×2×2 arrays.

This is my first time using JAX differentiation, and most of the examples I find online are about functions that have scalar output.

From my search, I understand that one way to do this is to take the gradient of each scalar in these arrays one at a time (probably made faster using vmap). Is there another way that is faster? I think jax.jacobian might do the trick, but I am having a hard time finding its documentation to see what the function does exactly. Any help is very much appreciated.

I have tried jax.jacobian on simple examples, and it does give me the answer I expect. This reassures me a bit, but I would like to find official documentation, or confirmation from others, that this is the right way to do it and that it does what I expect.


Solution

  • You can use jax.jacobian for what you describe. Here is an example:

    import jax
    import jax.numpy as jnp
    
    def f(x):
      a = jnp.full((2, 2), 2) * x
      b = jnp.full((2, 2), 3) * x
      c = jnp.full((2, 2), 4) * x
      return [a, b, c]
    
    da_dx, db_dx, dc_dx = jax.jacobian(f)(1.0)
    
    print(da_dx)
    # [[2. 2.]
    #  [2. 2.]]
    
    print(db_dx)
    # [[3. 3.]
    #  [3. 3.]]
    
    print(dc_dx)
    # [[4. 4.]
    #  [4. 4.]]
    

    jax.jacobian is an alias of jax.jacrev, and you can find the documentation here: https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html
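
As a rough check on the 2×2×2 case from the question, the sketch below compares jax.jacrev against jax.jacfwd and against the element-by-element jax.grad approach. The function f and the constant BASE are hypothetical stand-ins, not the asker's actual function. With a single scalar input and many outputs, forward mode (jax.jacfwd) is usually the cheaper of the two, although both should return the same values.

```python
import jax
import jax.numpy as jnp

# Hypothetical constant used only to build an illustrative function.
BASE = jnp.arange(8.0).reshape(2, 2, 2)

def f(x):
    # Stand-in for the asker's function: three 2x2x2 outputs of a scalar x.
    a = BASE * x             # da/dx = BASE
    b = jnp.sin(BASE * x)    # db/dx = BASE * cos(BASE * x)
    c = BASE + x ** 2        # dc/dx = 2 * x for every element
    return [a, b, c]

x0 = 1.5

# Both calls return a list with the same structure as f's output;
# each entry has the shape of the corresponding output array.
jac_rev = jax.jacrev(f)(x0)   # reverse mode (what jax.jacobian aliases)
jac_fwd = jax.jacfwd(f)(x0)   # forward mode

# Element-by-element reference for da/dx: one jax.grad call per scalar.
da_dx_manual = jnp.stack(
    [jax.grad(lambda x, i=i: f(x)[0].ravel()[i])(x0) for i in range(8)]
).reshape(2, 2, 2)

for name, rev, fwd in zip("abc", jac_rev, jac_fwd):
    print(name, rev.shape, bool(jnp.allclose(rev, fwd, atol=1e-5)))
```

Each entry of the result has shape (2, 2, 2), matching the output it was derived from, and da_dx_manual should agree with the first entry of either Jacobian.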