Is it possible to use jax.vmap for auto-batching if your function isn't jittable?


I have a function that's not jittable:

def testfunc(model, x1, x2, x2_mask):
    ...  # non-jittable stuff with masks

I'm trying to wrap it in vmap so I can benefit from auto-batching as explained here.

So I do:

testfunc_batched = jax.vmap(testfunc, in_axes=(None, 0, 0, 0))

The intention is that in batched mode, each of x1, x2, and x2_mask has an additional outer dimension, the batching dimension. The model shouldn't be treated differently in batched mode, hence the None. Let me know if the syntax isn't right.
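As a sanity check on the in_axes syntax, here is a minimal sketch with a hypothetical, jittable stand-in for testfunc (toy_func, with model assumed to be a dict of parameters). None passes model unchanged to every batch element, while 0 maps over the leading axis of x1, x2, and x2_mask.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for testfunc. It stays jittable because the mask is
# applied by multiplication, which keeps every array shape fixed.
def toy_func(model, x1, x2, x2_mask):
    return model["w"] * jnp.sum(x1) + jnp.sum(x2 * x2_mask)

# None: model is shared across the batch; 0: map over axis 0 of the others.
toy_batched = jax.vmap(toy_func, in_axes=(None, 0, 0, 0))

model = {"w": 2.0}                          # assumed parameter pytree
x1s = jnp.ones((3, 4))                      # batch of 3, each x1 has shape (4,)
x2s = jnp.arange(15.0).reshape(3, 5)        # batch of 3, each x2 has shape (5,)
x2_masks = x2s > 6.0                        # boolean mask with the same shape

out = toy_batched(model, x1s, x2s, x2_masks)
out_jit = jax.jit(toy_batched)(model, x1s, x2s, x2_masks)  # also traces fine
print(out.shape)
```

Because the stand-in body is traceable, the same in_axes spec works with and without jax.jit, which suggests the vmap call itself is written correctly and any failure comes from the function body.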

To test, I create batches of size one by adding a leading batch axis of length 1:

x1s = x1[None, ...]
x2s = x2[None, ...]
x2_masks = x2_mask[None, ...]

testfunc_batched(model, x1s, x2s, x2_masks)

The last line fails with ConcretizationTypeError.

I've recently learned that operations with boolean masks can make functions non-jittable. Does that mean I also can't use vmap, or am I doing something wrong?
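A minimal reproduction, assuming the non-jittable part of testfunc is boolean-mask indexing (masked_sum_bad is a hypothetical example, not the question's code): the function runs fine on concrete arrays, but raises an error under both jax.jit and jax.vmap, because x[mask] has a shape that depends on the values in mask.

```python
import jax
import jax.numpy as jnp

# Hypothetical function: x[mask] returns an array whose length depends on
# how many entries of mask are True, which is unknown while JAX traces.
def masked_sum_bad(x, mask):
    return jnp.sum(x[mask])

x = jnp.arange(6.0).reshape(2, 3)   # two rows: [0, 1, 2] and [3, 4, 5]
mask = x > 2.0                      # rows: [F, F, F] and [T, T, T]

def error_name(fn):
    """Call fn() and return the name of the exception it raises, or None."""
    try:
        fn()
    except Exception as e:  # JAX raises a tracer/indexing error here
        return type(e).__name__
    return None

eager = masked_sum_bad(x[1], mask[1])  # works: plain concrete arrays
jit_err = error_name(lambda: jax.jit(masked_sum_bad)(x[0], mask[0]))
vmap_err = error_name(lambda: jax.vmap(masked_sum_bad)(x, mask))
print(eager, jit_err, vmap_err)
```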

(There is further context in "How to JIT code involving masked arrays without NonConcreteBooleanIndexError?", but you don't have to read that question to understand this one.)


Solution

  • Is it possible to use jax.vmap for auto-batching if your function isn't jittable?

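    No. In general, functions that are incompatible with jit are also incompatible with vmap, because both jit and vmap use the same JAX tracing mechanism to transform the program. The function is called with tracer objects standing in for the arrays, so any operation that needs concrete values, such as a Python if on an array or boolean-mask indexing whose output shape depends on the mask, fails under either transformation.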
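The usual way around this, sketched here under the assumption that the non-jittable part of testfunc is boolean-mask indexing (masked_sum is a hypothetical example, not code from the question), is to express the mask with fixed-shape operations such as jnp.where. The same function then works under jit, vmap, and their composition.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: sum the entries of x selected by a boolean mask.
# jnp.where keeps the output shape equal to x's shape and replaces the
# unselected entries with 0.0, so no shape depends on the mask's values.
def masked_sum(x, mask):
    return jnp.sum(jnp.where(mask, x, 0.0))

x = jnp.arange(6.0).reshape(2, 3)   # two rows: [0, 1, 2] and [3, 4, 5]
mask = x > 2.0

batched = jax.vmap(masked_sum)          # in_axes defaults to 0 for every argument
out = batched(x, mask)
out_jit = jax.jit(batched)(x, mask)     # jit and vmap compose without errors
print(out, out_jit)
```

This only helps when the masked computation can be written with fixed shapes (sums, counts, elementwise selections, and so on); if downstream code genuinely needs a variable-length array, neither transformation can trace it.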