
Does Haiku cache parameters when combined with jax.vmap?


I have a Haiku module whose __call__ method looks like this:

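import haiku as hk
import jax
import jax.numpy as jnp

class MyModule(hk.Module):
    def __call__(self, x):
        # Parameters; shapes and initializers are defined elsewhere
        A = hk.get_parameter("A", shape=[self.Ashape], init=A_init)
        B = hk.get_parameter("B", shape=[self.Bshape], init=B_init)
        # Expensive step that depends only on the parameters, not on x
        C = self.demanding_computation(A, B)

        res = easy_computation(C, x)
        return res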

I use this module as follows:

def _forward(x):
    module = MyModule()
    return module(x)


forward = hk.without_apply_rng(hk.transform(_forward))
x_test = jnp.ones(1)
params = forward.init(jax.random.PRNGKey(42), x_test)
# Share the same params across the batch (None), map over axis 0 of x
f = jax.vmap(forward.apply, in_axes=(None, 0))

I then apply f with the same params to many different x. Is demanding_computation (which does not depend on x) cached within the jax.vmap call? If not, what is the correct pattern for separating these computations so that demanding_computation is cached?
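One way to make the separation explicit, regardless of how vmap treats code that does not depend on x, is to compute C once outside the vmap and batch only the cheap step. The sketch below uses plain JAX rather than Haiku, and demanding_computation and easy_computation are hypothetical stand-ins for the question's functions. Wrapping the whole function in jax.jit compiles it as a single program.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's functions; the real ones
# would use the Haiku parameters A and B.
def demanding_computation(A, B):
    # Expensive step that depends only on the parameters.
    return jnp.outer(A, B).sum(axis=1)

def easy_computation(C, x):
    # Cheap step that depends on the per-example input x.
    return C * x

def apply_split(params, xs):
    A, B = params
    C = demanding_computation(A, B)  # computed once, outside the vmap
    # Only the cheap, x-dependent part is batched; C is shared (None).
    per_example = jax.vmap(easy_computation, in_axes=(None, 0))
    return per_example(C, xs)

apply_split_jit = jax.jit(apply_split)

params = (jnp.arange(3.0), jnp.ones(4))
xs = jnp.arange(5.0)
out = apply_split_jit(params, xs)
print(out.shape)  # (5, 3)
```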

I tried to test this by adding a print via id_print from jax.experimental.host_callback:

    def demanding_computation(self, A, B):
        C = compute(A, B)
        id_print(C)
        return C

It indeed printed only once. Is that sufficient evidence that this computation is actually cached, or is only the printing skipped on subsequent iterations?


Solution

  • demanding_computation will only be called once, but not because of caching.

    vmap doesn't loop over the batched axis; it traces the function once and replaces each operation with a vectorized version (e.g. scalar additions become vector additions). Since demanding_computation doesn't take any inputs with a batch axis, it is left unchanged by this use of vmap. (Even if it did, it would still run only once; it would just be a vectorized version.)
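A quick way to see this without relying on host callbacks is to count how many times the Python function body runs. The sketch below assumes plain JAX and a hypothetical demanding_computation, and compares vmap against an explicit Python loop over the batch.

```python
import jax
import jax.numpy as jnp

# Counter incremented every time the Python body of the (hypothetical)
# expensive function runs.
call_count = {"demanding": 0}

def demanding_computation(A, B):
    # Hypothetical stand-in for the expensive, x-independent step.
    call_count["demanding"] += 1
    return jnp.outer(A, B).sum(axis=1)

def apply(params, x):
    A, B = params
    C = demanding_computation(A, B)  # depends only on params
    return C * x                     # cheap, x-dependent step

params = (jnp.arange(3.0), jnp.ones(4))
xs = jnp.arange(5.0)

# vmap calls `apply` once for the whole batch.
f = jax.vmap(apply, in_axes=(None, 0))
batched = f(params, xs)
vmap_calls = call_count["demanding"]

# For comparison: a Python loop calls it once per element.
call_count["demanding"] = 0
looped = jnp.stack([apply(params, x) for x in xs])
loop_calls = call_count["demanding"]

print(vmap_calls, loop_calls)  # 1 5
```

Note that vmap itself keeps no cache between calls: calling f again runs apply again, so the counter grows by one per call to f, not per batch element.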