I have a vector-jacobian product that I want to compute.
The function func takes four arguments, the final two of which are static:
def func(variational_params, e, A, B):
    ...
    return model_params, dlogp, ...
The function jits perfectly fine via
func_jitted = jit(func, static_argnums=(2, 3))
The primals are the variational_params, and the cotangents are dlogp (the second output of the function).
Calculating the vector-jacobian product naively (by forming the jacobian) works fine:
jacobian_func = jacobian(func_jitted, argnums=0, has_aux=True)
jacobian_jitted = jit(jacobian_func, static_argnums=(2, 3))
jac, func_output = jacobian_jitted(variational_params, e, A, B)
naive_vjp = func_output.T @ jac
When trying to form the vjp in an efficient manner via
f_eval, vjp_function, aux_output = vjp(func_jitted, variational_params, e, A, B, has_aux=True)
I get the following error:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.ad.JVPTracer'> for function func is non-hashable.
I am a little confused, as the function func jitted perfectly fine... there is no option for adding static_argnums to the vjp function, so I am not too sure what this means.
For higher-level transformation APIs like jit, JAX generally provides a mechanism like static_argnums or argnums to allow specification of static vs. dynamic variables.
For lower-level transformation routines like jvp and vjp, these mechanisms are not provided, but you can still accomplish the same thing by passing partially-evaluated functions. For example:
from functools import partial
f_eval, vjp_function, aux_output = vjp(partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)
This is effectively how transformation parameters like argnums and static_argnums are implemented under the hood.
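To make the pattern concrete, here is a minimal self-contained sketch. The body of func below is a hypothetical stand-in (the question elides the real one), and the values of A, B, and the arrays are made up for illustration. It binds the static arguments in a closure, which plays the same role as the partial above while passing A and B positionally, and then checks the result against the naive Jacobian approach from the question.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the question's func: A and B are static,
# hashable Python values, and the function returns (output, aux).
def func(variational_params, e, A, B):
    model_params = A * variational_params + B * e
    dlogp = jnp.sin(model_params)
    return model_params, dlogp


func_jitted = jax.jit(func, static_argnums=(2, 3))

# Made-up inputs for illustration only.
variational_params = jnp.arange(3.0)
e = jnp.ones(3)
A, B = 2, 3


# Bind the static arguments first, so that vjp only sees (and traces)
# the two array arguments. A and B reach func_jitted as plain Python ints.
def f(variational_params, e):
    return func_jitted(variational_params, e, A, B)


model_params, vjp_function, dlogp = jax.vjp(
    f, variational_params, e, has_aux=True
)

# vjp_function returns one cotangent per primal argument of f.
vjp_params, vjp_e = vjp_function(dlogp)

# Naive check: form the full Jacobian w.r.t. variational_params.
jac, aux = jax.jacobian(func_jitted, argnums=0, has_aux=True)(
    variational_params, e, A, B
)
naive_vjp = aux @ jac
```

Here vjp_params should agree with naive_vjp without the full Jacobian ever being materialized. Because vjp differentiates with respect to every argument it is given, it also returns a cotangent for e; if only variational_params matters, e can be bound in the closure as well.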