I have a Haiku module whose __call__ method looks like this:
class MyModule(hk.Module):
    def __call__(self, x):
        A = hk.get_parameter("A", shape=[self.Ashape], init=A_init)
        B = hk.get_parameter("B", shape=[self.Bshape], init=B_init)
        C = self.demanding_computation(A, B)
        res = easy_computation(C, x)
        return res
I use this module via
def _forward(x):
    module = MyModule()
    return module(x)

forward = hk.without_apply_rng(hk.transform(_forward))
x_test = jnp.ones(1)
params = forward.init(jax.random.PRNGKey(42), x_test)
f = jax.vmap(forward.apply, in_axes=(None, 0))
Then I apply f with the same params to many different x. Is demanding_computation (which does not depend on x) then cached within the jax.vmap call? If not, what is the correct pattern for separating these computations so that demanding_computation is cached?
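For concreteness, the batched call I mean is something like this (xs is just a hypothetical batch of inputs):

xs = jnp.stack([x_test] * 100)   # hypothetical batch, shape (100, 1)
ys = f(params, xs)               # params broadcast (in_axes=None), xs mapped over axis 0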
I have tried to test this by adding a print statement from jax.experimental.host_callback:
def demanding_computation(self, A, B):
    C = compute(A, B)
    id_print(C)
    return C
and it indeed only printed once. Is that sufficient evidence that this computation is actually cached, or is only the printing omitted in subsequent iterations?
demanding_computation will only be called once, but not because of caching.

vmap doesn't loop over the batched axes; it replaces the operations with vectorized versions (e.g. scalar additions become vector additions). Since demanding_computation doesn't involve inputs with batch axes, it won't be modified by this use of vmap. (Even if it did, it would still only run once; it would just be a vectorized version.)
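You can see this with a stand-alone sketch (not your MyModule, just made-up shapes and names): a Python-level print inside the expensive function fires once per call to the vmapped function, because the function body is only traced once and the part with no batch axis is executed as ordinary, un-vectorized ops.

import jax
import jax.numpy as jnp

def demanding_computation(A, B):
    print("demanding_computation runs")  # Python-level print: executes once per call/trace
    return A @ B

def forward(params, x):
    A, B = params
    C = demanding_computation(A, B)       # no batch axis: stays a single (4, 4) matmul
    return C @ x                          # per-example part: this is what vmap vectorizes

params = (jnp.ones((4, 4)), jnp.ones((4, 4)))
xs = jnp.ones((100, 4))                   # batch of 100 inputs

f = jax.vmap(forward, in_axes=(None, 0))
ys = f(params, xs)                        # prints "demanding_computation runs" exactly once
print(ys.shape)                           # (100, 4)

If you want to inspect what vmap actually produces, jax.make_jaxpr(f)(params, xs) should show the A @ B matmul appearing once, with only the per-example part turned into a batched operation. Note that this is not caching: the expensive part still runs once on every call to f, just not once per batch element.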