
Getting the expected dimensions of the Jacobian with JAX?


I am trying to compute the Jacobian of a simple reparameterization function in JAX. The code is as follows:

# imports
import jax
import jax.numpy as jnp
from jax import jacobian, random, vmap

# simple reparameterization function; note that eps is read from the enclosing (global) scope
def reparameterize(v_params):
    theta = v_params[0] + jnp.exp(v_params[1]) * eps
    return theta

Suppose I initialize eps to be a vector of shape (3,) and v_params to be of shape (3, 2):

key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))
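A quick shape check with this setup shows where the extra axis comes from: a single row of v_params produces an output shaped like eps, and vmap stacks one such output per row. This is a sketch that repeats the question's definitions so it runs on its own.

```python
import jax.numpy as jnp
from jax import random, vmap

# Same setup as in the question: eps is read from the enclosing (global) scope.
key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))

def reparameterize(v_params):
    return v_params[0] + jnp.exp(v_params[1]) * eps

# One row of v_params has shape (2,); the scalar terms broadcast against all of eps.
single = reparameterize(v_params[0])
print(single.shape)   # (3,)

# vmap maps over the leading axis of v_params and stacks the per-row results.
batched = vmap(reparameterize)(v_params)
print(batched.shape)  # (3, 3)
```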

I want the Jacobian to be an array of shape (3, 2), but calling

jacobian(vmap(reparameterize))(v_params)

returns an array of shape (3, 3, 3, 2). If I re-initialize with only a single eps:

key, _ = random.split(key)
eps = random.normal(key, shape=(1, ))
key, _ = random.split(key)
v_params = random.normal(key, shape=(2, ))

and call jacobian(reparameterize)(v_params), I get what I want, i.e., one derivative per parameter (strictly, the result has shape (1, 2), because the output has shape (1,)). Effectively, looping over all entries of eps and stacking the individual Jacobians gives me the desired Jacobian (and shape). What am I missing here? Thanks for your help!
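If the intent is to pair row i of v_params with eps[i] (an assumption based on the loop described above), a minimal sketch is to pass eps explicitly and vmap the Jacobian rather than the function being differentiated. The helper reparameterize_one below is hypothetical and not part of the original code.

```python
import jax.numpy as jnp
from jax import jacobian, random, vmap

key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))

# Hypothetical variant: takes one (mean, log-scale) row and one scalar eps
# explicitly instead of reading eps from the global scope.
def reparameterize_one(v, e):
    return v[0] + jnp.exp(v[1]) * e  # scalar output

# Explicit loop: one Jacobian of shape (2,) per (row, eps) pair, then stack.
looped = jnp.stack([jacobian(reparameterize_one)(v_params[i], eps[i])
                    for i in range(eps.shape[0])])

# Vectorized equivalent: vmap over both arguments (in_axes=0 by default).
per_sample = vmap(jacobian(reparameterize_one))(v_params, eps)

print(looped.shape, per_sample.shape)  # (3, 2) (3, 2)
```

Because each per-sample output is a scalar, each Jacobian has shape (2,): its entries are 1 and exp(v[1]) * e.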


Solution

  • For a function f that maps an input of shape shape_in to an output of shape shape_out, the jacobian will have shape (*shape_out, *shape_in).

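    In your case, vmap(reparameterize) maps over the leading axis of v_params: each call receives one row of shape (2,) and returns v_params[0] + jnp.exp(v_params[1]) * eps, which broadcasts to the shape of eps, (3,). Stacking the three rows gives an output of shape (3, 3) from an input of shape (3, 2), so the jacobian has shape (3, 3, 3, 2).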

    It's hard to tell from your question what computation you intended, but if you want a jacobian with the same shape as the input, you need a function that maps the input to a scalar. Perhaps the sum is what you had in mind?

    result = jacobian(lambda x: vmap(reparameterize)(x).sum())(v_params)
    print(result.shape)
    # (3, 2)
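A sketch checking both shapes with the question's setup, and what the summed version actually computes. Note that the gradient of the sum adds up contributions from all nine outputs of vmap(reparameterize), so it is not the same as one Jacobian per eps entry.

```python
import jax.numpy as jnp
from jax import jacobian, random, vmap

key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))

def reparameterize(v_params):
    return v_params[0] + jnp.exp(v_params[1]) * eps

batched = vmap(reparameterize)

# Full Jacobian: output shape (3, 3) followed by input shape (3, 2).
full = jacobian(batched)(v_params)
print(full.shape)  # (3, 3, 3, 2)

# Jacobian of the summed output: scalar output, so the shape matches the input.
summed = jacobian(lambda x: batched(x).sum())(v_params)
print(summed.shape)  # (3, 2)

# Row i contributes eps.size outputs, so d(sum)/dv[i, 0] = eps.size
# and d(sum)/dv[i, 1] = exp(v[i, 1]) * eps.sum().
print(jnp.allclose(summed[:, 0], eps.size))
print(jnp.allclose(summed[:, 1], jnp.exp(v_params[:, 1]) * eps.sum()))
```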