I have a joint cumulative distribution function defined in Python as a function that takes a JAX array and returns a single value. Something like:
def cumulative(inputs: jax.Array) -> float:
    ...
To get the gradient, I know I can just do grad(cumulative), but that only gives me the first-order partial derivatives of cumulative with respect to the input variables.
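For example, with a toy stand-in for my real function (a product of independent standard-normal CDFs, not my actual CDF), grad only returns the vector of first-order partials:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def toy_cumulative(inputs):
    # toy stand-in: joint CDF of independent standard normals
    return jnp.prod(norm.cdf(inputs))

x = jnp.array([0.3, 1.2])
print(jax.grad(toy_cumulative)(x))  # one first-order partial per input, shape (2,)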
Instead, what I would like to compute is the mixed partial derivative of F with respect to every one of its input variables, assuming F is my function and f the joint probability density function:

f(x1, ..., xn) = ∂^n F(x1, ..., xn) / (∂x1 ∂x2 ... ∂xn)

The order of the partial differentiation doesn't matter.
So, I have several questions, the main one being: how can I compute this kind of higher-order mixed partial derivative in JAX?
JAX generally treats gradients as being with respect to individual arguments, not elements within arguments. Within this context, one built-in function that is similar to what you want (but not exactly the same) is jax.hessian, which computes the Hessian matrix of second-order partial derivatives; for example:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.prod(x ** 2)

x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
#  [72. 18. 24.]
#  [48. 24.  8.]]
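For what it's worth, jax.hessian is itself built by nesting derivative transforms (the JAX docs describe it as essentially jax.jacfwd(jax.jacrev(f))), which is the same nesting idea used below, just limited to second order. A quick check, reusing f and x from the snippet above:

# jax.hessian should match one Jacobian transform nested inside another
manual_hessian = jax.jacfwd(jax.jacrev(f))
print(jnp.allclose(jax.hessian(f)(x), manual_hessian(x)))
# True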
For higher-order derivatives with respect to individual elements of the array, I think you'll have to manually nest the gradients. You could do so with a helper function that looks something like this:
def grad_all(f):
    def gradfun(x):
        # Unpack the array into individual scalar arguments so each element
        # can be treated as a separate argument to differentiate against.
        args = tuple(x)
        f_args = lambda *args: f(jnp.array(args))
        # Nest jax.grad once per element; the result is the mixed partial
        # derivative of f taken once with respect to every input.
        for i in range(len(args)):
            f_args = jax.grad(f_args, argnums=i)
        return f_args(*args)
    return gradfun

print(grad_all(f)(x))
# 48.0
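To connect this back to the CDF-to-PDF use case, here's a minimal sketch using a hypothetical joint CDF of independent unit-rate exponentials as a stand-in for your cumulative (and the grad_all helper defined above); the nested gradient recovers the analytic joint density:

import jax.numpy as jnp

def cdf(x):
    # hypothetical joint CDF: independent unit-rate exponentials
    return jnp.prod(1.0 - jnp.exp(-x))

x = jnp.array([0.5, 1.0, 1.5])
print(grad_all(cdf)(x))      # mixed partial of the CDF w.r.t. every input
print(jnp.exp(-jnp.sum(x)))  # analytic joint density, should match (~0.0498)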