Tags: python, probability-density, probability-distribution, automatic-differentiation, jax

How to compute the joint probability density function from a joint cumulative distribution function in JAX?


I have a joint cumulative distribution function defined in Python as a function that takes a JAX array and returns a single value. Something like:

def cumulative(inputs: jax.Array) -> float:
    ...

To get the gradient, I know I can just do grad(cumulative), but that only gives me the first-order partial derivatives of cumulative with respect to the input variables. Instead, what I would like to compute is this, assuming F is my function and f the joint probability density function:

$$ f(x_1, \dots, x_n) = \frac{\partial^n F(x_1, \dots, x_n)}{\partial x_1 \, \partial x_2 \cdots \partial x_n} $$

The order of differentiation doesn't matter.
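
For concreteness, with a toy stand-in for my real CDF (toy_cumulative below is just a placeholder, not my actual function), grad gives me the vector of first-order partials:

import jax
import jax.numpy as jnp

# placeholder joint CDF: product of independent logistic marginal CDFs
def toy_cumulative(inputs):
    return jnp.prod(jax.nn.sigmoid(inputs))

x = jnp.array([0.1, 0.2, 0.3])
print(jax.grad(toy_cumulative)(x))  # vector of dF/dx_i, not the mixed partial I want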

So, I have several questions:

  • How do I compute this efficiently in JAX? I assume I cannot just call grad n times.
  • Once the resulting function is computed, will it have a higher call complexity than the original function (does it grow by a factor of O(n), stay constant, or something else)?
  • Alternatively, how can I compute a single partial derivative with respect to only one of the variables in the input array, as opposed to the entire array? (I would then just repeat this n times, once per variable.)

Solution

  • JAX generally treats gradients as being with respect to individual arguments, not elements within arguments. Within this context, one built-in function that is similar to what you want to do (but not exactly the same) is jax.hessian, which computes the Hessian matrix of second derivatives; for example:

    import jax
    import jax.numpy as jnp
    
    def f(x):
      return jnp.prod(x ** 2)
    
    x = jnp.arange(1.0, 4.0)
    print(jax.hessian(f)(x))
    # [[72. 72. 48.]
    #  [72. 18. 24.]
    #  [48. 24.  8.]]
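
    As a side note on that per-argument behavior: the argnums parameter of jax.grad selects which argument to differentiate, not which element of an array argument. A small sketch (the two-argument function g below is only an illustration):

    def g(a, b):
      return a ** 2 * b

    print(jax.grad(g, argnums=0)(2.0, 3.0))  # dg/da = 2 * a * b = 12.0
    print(jax.grad(g, argnums=1)(2.0, 3.0))  # dg/db = a ** 2 = 4.0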
    

    For higher-order derivatives with respect to individual elements of the array, I think you'll have to manually nest the gradients. You could do so with a helper function that looks something like this:

    def grad_all(f):
      # Unpack the array into separate scalar arguments, then nest jax.grad
      # once per argument so each application differentiates with respect to
      # the next variable.
      def gradfun(x):
        args = tuple(x)
        f_args = lambda *args: f(jnp.array(args))
        for i in range(len(args)):
          f_args = jax.grad(f_args, argnums=i)
        return f_args(*args)
      return gradfun
    
    print(grad_all(f)(x))
    # 48.0
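
    Regarding your last question (a single partial derivative with respect to just one element of the input array), one option is to treat that element as a separate scalar argument. A minimal sketch, where partial_wrt is my own helper name, not a JAX API:

    def partial_wrt(f, i):
      # derivative of f with respect to element i of its array input only
      def fi(xi, x):
        return f(x.at[i].set(xi))
      return lambda x: jax.grad(fi, argnums=0)(x[i], x)

    print(partial_wrt(f, 1)(x))
    # 36.0  (the first derivative of f with respect to x[1] at [1., 2., 3.])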