I have the following code, which computes derivatives of a function:
import jax
import jax.numpy as jnp
def f(x):
    return jnp.prod(x)
df1 = jax.grad(f)
df2 = jax.jacobian(df1)
df3 = jax.jacobian(df2)
With this, all the partial derivatives are available; for example (adding vmap to handle a batch of inputs):
x = jnp.array([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25.],
[26., 27., 28., 29., 30.]])
df3_x0_x2_x4 = jax.vmap(df3)(x)[:, 0, 2, 4]
print(df3_x0_x2_x4)
# [ 8. 63. 168. 323. 528. 783.]
The question is: how can I compute df3_x0_x2_x4 only, avoiding all the unnecessary derivative calculations (and leaving f with a single vector argument)?
Essentially, you're asking for a way to compute sparse Hessians and Jacobians; JAX does not have general support for this (see previous issue threads, e.g. https://github.com/google/jax/issues/1032).
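To see the cost concretely, here is a quick sketch (reusing f from above) showing that each jax.jacobian call produces a fully dense array, so the third-derivative tensor has 5×5×5 entries per input row even though only one of them is wanted:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.prod(x)

# three nested dense derivative transforms
df3 = jax.jacobian(jax.jacobian(jax.grad(f)))

x = jnp.arange(1.0, 31.0).reshape(6, 5)
out = jax.vmap(df3)(x)
print(out.shape)  # (6, 5, 5, 5): all 125 third derivatives per row
```

All 125 third partial derivatives are materialized per row before the single desired entry is sliced out.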
Edit
In this particular case, though, since you're effectively computing the gradient/Jacobian with respect to a single element per derivative pass, you can do better by applying a JVP with a one-hot tangent vector in each transformation. For example:
def deriv(f, x, v):
    # directional derivative of f at x along v, via forward-mode JVP
    return jax.jvp(f, [x], [v])[1]
def one_hot(i):
    # unit basis vector e_i (note: reads x from the enclosing scope for its length)
    return jnp.zeros(x.shape[1]).at[i].set(1)
df_x0 = lambda x: deriv(f, x, one_hot(0))
df2_x0_x2 = lambda x: deriv(df_x0, x, one_hot(2))
df3_x0_x2_x4 = lambda x: deriv(df2_x0_x2, x, one_hot(4))
print(jax.vmap(df3_x0_x2_x4)(x))
# [ 8. 63. 168. 323. 528. 783.]
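As a sanity check, here is a self-contained sketch verifying that chaining three one-hot JVPs reproduces exactly the corresponding slice of the full dense third-derivative tensor:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.prod(x)

def deriv(f, x, v):
    # directional derivative of f at x along v, via forward-mode JVP
    return jax.jvp(f, [x], [v])[1]

def one_hot(i, n=5):
    # unit basis vector e_i; n=5 matches the row length of x below
    return jnp.zeros(n).at[i].set(1.0)

# chain of three directional derivatives along e0, e2, e4
df_x0 = lambda x: deriv(f, x, one_hot(0))
df2_x0_x2 = lambda x: deriv(df_x0, x, one_hot(2))
df3_x0_x2_x4 = lambda x: deriv(df2_x0_x2, x, one_hot(4))

x = jnp.arange(1.0, 31.0).reshape(6, 5)
full = jax.vmap(jax.jacobian(jax.jacobian(jax.grad(f))))(x)[:, 0, 2, 4]
sparse = jax.vmap(df3_x0_x2_x4)(x)
print(jnp.allclose(full, sparse))  # True
```

Each JVP contracts one derivative axis with a basis vector as it goes, so nothing larger than the input vector is ever materialized.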
Previous answer
If you're willing to relax your "leaving f with a single argument" criterion, you could do something like this:
def f(*x):
    return jnp.prod(jnp.asarray(x))
df1 = jax.grad(f, argnums=4)
df2 = jax.jacobian(df1, argnums=2)
df3 = jax.jacobian(df2, argnums=0)
df3_x0_x2_x4 = jax.vmap(df3)(*(x.T))
print(df3_x0_x2_x4)
# [ 8. 63. 168. 323. 528. 783.]
Here rather than computing all gradients and slicing out the result, you are only computing the gradients with respect to the specific three elements you are interested in.