Is it possible to use vmap for auto-batching if your function isn't jittable?
I have a function that's not jittable:
def testfunc(model, x1, x2, x2_mask):
    ...  # non-jittable stuff with masks
I'm trying to wrap it in vmap so I can benefit from auto-batching as explained here.
So I do:
testfunc_batched = jax.vmap(testfunc, in_axes=(None, 0, 0, 0))
The intention is that in batched mode, each of x1, x2, and x2_mask will have an additional outer dimension, the batching dimension. The model shouldn't be treated differently in batched mode, hence the None. Let me know if the syntax isn't right.
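The `in_axes=(None, 0, 0, 0)` syntax is correct. A minimal sketch of what it does, using a hypothetical jittable stand-in for `testfunc` (the name `masked_score`, the `model` dict and all shapes are assumptions, not the original code): `None` passes the model unchanged to every example, while `0` maps over the leading axis of the other three arguments.

```python
import jax
import jax.numpy as jnp

def masked_score(model, x1, x2, x2_mask):
    # Hypothetical stand-in for testfunc that uses only jittable operations.
    # model is a dict of parameters shared by every example (not batched).
    h = x1 @ model["w"]                        # shape (d,)
    scores = x2 @ h                            # shape (n,)
    # Zero out masked-off positions instead of dropping them.
    return jnp.sum(jnp.where(x2_mask, scores, 0.0))

# None: model is broadcast; 0: map over axis 0 of x1, x2 and x2_mask.
batched = jax.vmap(masked_score, in_axes=(None, 0, 0, 0))

d, n, batch = 3, 4, 5
model = {"w": jnp.ones((d, d))}
x1s = jnp.ones((batch, d))                     # leading axis is the batch axis
x2s = jnp.ones((batch, n, d))
x2_masks = jnp.ones((batch, n), dtype=bool)

out = batched(model, x1s, x2s, x2_masks)
print(out.shape)  # (5,)
```

Because `model` is marked `None`, the whole parameter dict (a pytree) is shared across the batch rather than being sliced.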
I create batches of size one just to test, schematically:
x1s = x1[None, ...]  # add a leading batch axis of size 1
x2s = x2[None, ...]
x2_masks = x2_mask[None, ...]
testfunc_batched(model, x1s, x2s, x2_masks)
The last line fails with ConcretizationTypeError.
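A small reproduction, assuming the "non-jittable stuff with masks" is boolean-mask indexing (an assumption, since the original function body isn't shown). The hypothetical `masked_mean_indexing` works when called eagerly on concrete arrays, but fails under both `jax.jit` and `jax.vmap`, because the length of `x[mask]` depends on the values in the mask.

```python
import jax
import jax.numpy as jnp

def masked_mean_indexing(x, mask):
    # Boolean-mask indexing: the output length depends on the values in mask,
    # so JAX cannot determine the shape while tracing.
    return x[mask].mean()

x = jnp.arange(6.0).reshape(2, 3)
mask = jnp.array([[True, False, True], [False, True, True]])

# Works eagerly on a single, unbatched example: mean of [0., 2.].
eager = masked_mean_indexing(x[0], mask[0])
print(eager)  # 1.0

expected_errors = (jax.errors.NonConcreteBooleanIndexError,
                   jax.errors.ConcretizationTypeError)

def error_name(f):
    # Return the name of the expected JAX error raised by f, or None.
    try:
        f()
    except expected_errors as e:
        return type(e).__name__
    return None

jit_err = error_name(lambda: jax.jit(masked_mean_indexing)(x[0], mask[0]))
vmap_err = error_name(lambda: jax.vmap(masked_mean_indexing)(x, mask))
print(jit_err, vmap_err)
```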
I've recently learned that stuff with masks makes functions not jittable. But does that mean that I also can't use vmap? Or am I doing something wrong?
(There is further context in How to JIT code involving masked arrays without NonConcreteBooleanIndexError?, but you don't have to read that question to understand this one.)
Is it possible to use jax.vmap for auto-batching if your function isn't jittable?
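No. In general, functions which are incompatible with jit will also be incompatible with vmap, because both jit and vmap use the same JAX tracing mechanism to transform the program. Under either transformation your function receives tracers instead of concrete arrays, so operations whose output shape depends on array values (such as indexing with a boolean mask), or Python control flow on array values, raise errors such as ConcretizationTypeError or NonConcreteBooleanIndexError.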
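The usual way out is to rewrite the masked operation so that every intermediate has a fixed shape, for example with `jnp.where` instead of boolean indexing. The sketch below assumes the masked operation is a mean over the selected entries (the name `masked_mean_where` is a hypothetical stand-in for the original function); once rewritten this way, it can be both vmapped and jitted.

```python
import jax
import jax.numpy as jnp

def masked_mean_where(x, mask):
    # Keep the full shape: replace masked-off entries with 0 and divide by the
    # number of selected entries, instead of dropping entries with x[mask].
    # (Yields nan if no entry is selected, since that is 0 / 0.)
    total = jnp.sum(jnp.where(mask, x, 0.0))
    count = jnp.sum(mask)
    return total / count

x = jnp.arange(6.0).reshape(2, 3)
mask = jnp.array([[True, False, True], [False, True, True]])

batched = jax.jit(jax.vmap(masked_mean_where))
out = batched(x, mask)
print(out)  # per-row means over selected entries: 1.0 and 4.5
```

The trade-off is that the computation always runs over all entries, including masked-off ones; that fixed-shape work is what lets JAX trace the function once and apply it to the whole batch.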