I am trying to get the Jacobian for a simple parameterization function within JAX. The code is as follows:
# imports
import jax
import jax.numpy as jnp
from jax import jacobian, random, vmap
# simple parameterization function
def reparameterize(v_params):
    # eps is a global array, defined below
    theta = v_params[0] + jnp.exp(v_params[1]) * eps
    return theta
Suppose I initialize eps to be a vector of shape (3,) and v_params to be of shape (3, 2):
key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))
I want the Jacobian to be an array of shape (3, 2), but jacobian(vmap(reparameterize))(v_params) returns an array of shape (3, 3, 3, 2).
If I re-initialize with only a single eps:
key, _ = random.split(key)
eps = random.normal(key, shape=(1, ))
key, _ = random.split(key)
v_params = random.normal(key, shape=(2, ))
and call jacobian(reparameterize)(v_params), I get what I want, i.e., an array of shape (2,). Effectively, looping over all values of eps and stacking the resulting Jacobians gives me the desired Jacobian (and shape). What am I missing here? Thanks for your help!
For a function f that maps an input of shape shape_in to an output of shape shape_out, the Jacobian has shape (*shape_out, *shape_in).
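To make the shape rule concrete, here is a small sketch using a hypothetical linear function f (not from the question) that maps an input of shape (2, 3) to an output of shape (4,). Its Jacobian therefore has shape (4, 2, 3), and for a linear map it is just the weight matrix W (also hypothetical) reshaped to match.

```python
import jax.numpy as jnp
from jax import jacobian

# Hypothetical weights: f maps shape (2, 3) -> shape (4,).
W = jnp.arange(24.0).reshape(4, 6)

def f(x):
    # Flatten the (2, 3) input to (6,), then apply the linear map.
    return W @ x.reshape(-1)

x = jnp.ones((2, 3))
J = jacobian(f)(x)
print(J.shape)  # (4, 2, 3): output shape followed by input shape
```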
In your case, vmap(reparameterize) takes an array of shape (3, 2) and returns an array of shape (3, 3), so the Jacobian is an array of shape (3, 3, 3, 2).
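The sketch below reproduces the question's setup to confirm these shapes. It also checks the block structure of the result: since output row i depends only on input row i, the block J[i, :, k, :] is zero whenever i != k, and the nonzero block J[i, :, i, :] has shape (3, 2).

```python
import jax.numpy as jnp
from jax import jacobian, random, vmap

key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))

def reparameterize(v_params):
    # Per call v_params has shape (2,); eps (global) has shape (3,),
    # so each call returns shape (3,).
    return v_params[0] + jnp.exp(v_params[1]) * eps

out = vmap(reparameterize)(v_params)           # shape (3, 3)
J = jacobian(vmap(reparameterize))(v_params)   # shape (3, 3, 3, 2)
print(out.shape, J.shape)  # (3, 3) (3, 3, 3, 2)

# Off-diagonal blocks vanish: output row i does not depend on input row k != i.
for i in range(3):
    for k in range(3):
        if i != k:
            assert jnp.allclose(J[i, :, k, :], 0.0)
```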
It's hard to tell from your question what computation you intended, but if you want a Jacobian with the same shape as the input, you need a function that maps the input to a scalar. Perhaps the sum is what you had in mind?
result = jacobian(lambda x: vmap(reparameterize)(x).sum())(v_params)
print(result.shape)
# (3, 2)
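As a sanity check, the summed version can be compared against a hand derivation. The sum over all i, j of v_params[i, 0] + exp(v_params[i, 1]) * eps[j] has derivative eps.size with respect to each v_params[i, 0], and exp(v_params[i, 1]) * eps.sum() with respect to each v_params[i, 1]. The sketch below re-creates the question's setup and compares the two.

```python
import jax.numpy as jnp
from jax import jacobian, random, vmap

key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))

def reparameterize(v_params):
    return v_params[0] + jnp.exp(v_params[1]) * eps

result = jacobian(lambda x: vmap(reparameterize)(x).sum())(v_params)

# Hand-derived gradient of the summed output, column by column.
expected = jnp.stack(
    [jnp.full((3,), float(eps.size)), jnp.exp(v_params[:, 1]) * eps.sum()],
    axis=1,
)
print(result.shape)  # (3, 2)
print(jnp.allclose(result, expected, atol=1e-5))
```

Note that this Jacobian mixes all entries of eps into every row, which differs from the per-row pairing of v_params[i] with eps[i].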