
multivariate derivatives in jax - efficiency question


I have the following code, which computes higher-order derivatives of a function:

import jax
import jax.numpy as jnp


def f(x):
    return jnp.prod(x)


df1 = jax.grad(f)
df2 = jax.jacobian(df1)
df3 = jax.jacobian(df2)

With this, all the partial derivatives are available. For example (combined with vmap):

x = jnp.array([[ 1.,  2.,  3.,  4.,  5.],
               [ 6.,  7.,  8.,  9., 10.],
               [11., 12., 13., 14., 15.],
               [16., 17., 18., 19., 20.],
               [21., 22., 23., 24., 25.],
               [26., 27., 28., 29., 30.]])
df3_x0_x2_x4 = jax.vmap(df3)(x)[:, 0, 2, 4]
print(df3_x0_x2_x4)
# [  8.  63. 168. 323. 528. 783.]

The question is: how can I compute only df3_x0_x2_x4, avoiding all the unnecessary derivative calculations (while keeping f a function of a single vector argument)?


Solution

  • The question is: how can I compute only df3_x0_x2_x4, avoiding all the unnecessary derivative calculations (while keeping f a function of a single vector argument)?

    Essentially, you're asking for a way to compute sparse Hessians and Jacobians; JAX does not have general support for this (see previous issue threads, e.g. https://github.com/google/jax/issues/1032).
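To make the cost concrete (a sketch of what the dense approach materializes): the chained jacobian calls build the full n³ third-derivative tensor for every input row before anything is sliced out.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.prod(x)

# Dense third-order derivative: grad -> (n,), jacobian -> (n, n), jacobian -> (n, n, n).
df3 = jax.jacobian(jax.jacobian(jax.grad(f)))

x = jnp.ones((6, 5))
print(jax.vmap(df3)(x).shape)  # (6, 5, 5, 5): 125 entries per row when only 1 is wanted
```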

    Edit

    In this particular case, though, since each derivative pass is effectively taken with respect to a single element, you can do better by applying a JVP against a one-hot vector in each transformation. For example:

    def deriv(f, x, v):
      # Forward-mode directional derivative of f at x along direction v.
      return jax.jvp(f, [x], [v])[1]
    
    def one_hot(i):
      # Unit vector selecting coordinate i (length taken from the global x).
      return jnp.zeros(x.shape[1]).at[i].set(1)
    
    df_x0 = lambda x: deriv(f, x, one_hot(0))
    df2_x0_x2 = lambda x: deriv(df_x0, x, one_hot(2))
    df3_x0_x2_x4 = lambda x: deriv(df2_x0_x2, x, one_hot(4))
    print(jax.vmap(df3_x0_x2_x4)(x))
    # [  8.  63. 168. 323. 528. 783.]
    
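A quick self-contained check (a sketch) that the JVP chain agrees with slicing the dense third-derivative tensor:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.prod(x)

x = jnp.array([[1., 2., 3., 4., 5.],
               [6., 7., 8., 9., 10.]])

# Reference: dense third-derivative tensor, sliced afterwards.
ref = jax.vmap(jax.jacobian(jax.jacobian(jax.grad(f))))(x)[:, 0, 2, 4]

def deriv(g, xi, v):
    # One forward-mode directional derivative per differentiation order.
    return jax.jvp(g, [xi], [v])[1]

def one_hot(i):
    return jnp.zeros(x.shape[1]).at[i].set(1)

df_x0 = lambda xi: deriv(f, xi, one_hot(0))
df2_x0_x2 = lambda xi: deriv(df_x0, xi, one_hot(2))
df3_x0_x2_x4 = lambda xi: deriv(df2_x0_x2, xi, one_hot(4))

out = jax.vmap(df3_x0_x2_x4)(x)
print(jnp.allclose(out, ref))  # True
```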

    Previous answer

    If you're willing to relax your "leaving f with a single argument" criterion, you could do something like this:

    def f(*x):
      return jnp.prod(jnp.asarray(x))
    
    df1 = jax.grad(f, argnums=4)
    df2 = jax.jacobian(df1, argnums=2)
    df3 = jax.jacobian(df2, argnums=0)
    
    df3_x0_x2_x4 = jax.vmap(df3)(*(x.T))
    print(df3_x0_x2_x4)
    # [  8.  63. 168. 323. 528. 783.]
    

    Here, rather than computing all the derivatives and slicing out the result, you compute derivatives only with respect to the three specific elements you are interested in.
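If you still want a single-vector interface at the call site, a thin wrapper (a sketch; the unpacking relies on the vector's length being statically known) can bridge the two signatures:

```python
import jax
import jax.numpy as jnp

def f(*xs):
    return jnp.prod(jnp.asarray(xs))

# Differentiate only with respect to arguments 4, 2, and 0, in turn.
df3 = jax.jacobian(jax.jacobian(jax.grad(f, argnums=4), argnums=2), argnums=0)

def df3_x0_x2_x4(x):
    # Unpack the vector into the positional scalars the argnums version expects.
    return df3(*x)

x = jnp.array([[1., 2., 3., 4., 5.],
               [6., 7., 8., 9., 10.]])
print(jax.vmap(df3_x0_x2_x4)(x))  # [ 8. 63.]
```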