
How to slice jax arrays using jax tracer?


I am trying to modify a codebase to create a subarray from an existing array, using indices that are JAX tracers. When I pass these tracers directly as slice indices, I get the following error:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<...>with<...>, Traced<...>with<...>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

What is a possible workaround or solution for this?


Solution

  • There are two main workarounds here that may be applicable depending on your problem: using static indices, or using dynamic_slice.

    Quick background: one constraint of arrays used in JAX transformations like jit, vmap, etc. is that they must be statically shaped (see JAX Sharp Bits: Dynamic Shapes for some discussion of this).

    With that in mind, a function like f below will always fail, because i and j are non-static variables and so the shape of the returned array cannot be known at compile time:

    from jax import jit

    @jit
    def f(x, i, j):
      # i and j are traced here, so the length j - i is unknown at compile time
      return x[i:j]
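    To see the failure concretely, here is a minimal sketch that calls the jitted f and catches the error. It assumes a recent JAX release, where tracing x[i:j] with traced bounds raises the same IndexError quoted in the question.

```python
import jax.numpy as jnp
from jax import jit


@jit
def f(x, i, j):
    # Under jit, i and j become tracers, so the length j - i is unknown while tracing
    return x[i:j]


x = jnp.arange(10)
try:
    f(x, 2, 5)
    raised = False
except IndexError as err:
    raised = True
    # Print only the start of JAX's (long) error message
    print("IndexError:", str(err).splitlines()[0][:80])
```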
    

    One workaround for this is to make i and j static arguments in jit, so that the shape of the returned array will be static:

    from functools import partial
    from jax import jit

    @partial(jit, static_argnames=['i', 'j'])
    def f(x, i, j):
      return x[i:j]
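    A small usage sketch of the static version. The trace_count global is a hypothetical counter added only for illustration: Python side effects inside a jitted function run only while JAX traces it, so the counter shows when a new compilation happens. Passing i and j positionally works because jit maps static_argnames onto the matching positional parameters.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_count = 0  # hypothetical counter, for illustration only


@partial(jit, static_argnames=['i', 'j'])
def f(x, i, j):
    global trace_count
    trace_count += 1  # runs only when JAX traces f, not on cached calls
    return x[i:j]     # i and j are plain Python ints here, so the shape is static


x = jnp.arange(10)
print(f(x, 2, 5))   # [2 3 4]
print(f(x, 2, 5))   # same (i, j): served from the compilation cache
print(f(x, 0, 7))   # new (i, j): traced and compiled again
print(trace_count)  # 2
```

    Each distinct (i, j) pair gets its own compiled version, so this approach suits a small set of slice bounds rather than bounds that change on every call.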
    

    Because of the static shape constraint, this is the only way to keep such a function under jit. The trade-off is that jit traces and compiles a separate version of f for every distinct (i, j) pair it sees.

    Another flavor of slicing problem that can lead to the same error might look like this:

    @jit
    def f(x, i):
      return x[i:i + 5]
    

    This will also raise the non-static index error. It could be fixed as above by marking i as static, but here we know more: assuming that 0 <= i <= len(x) - 5 holds, the output array always has shape (5,). This is a case where jax.lax.dynamic_slice applies, since the slice size is fixed and only its location is dynamic:

    import jax
    from jax import jit

    @jit
    def f(x, i):
      return jax.lax.dynamic_slice(x, (i,), (5,))
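    A usage sketch for the dynamic_slice version: the start index may be traced, while the slice size is a static tuple. Because every call returns shape (5,), the function also composes with jax.vmap over a batch of start indices (the starts array here is only illustrative).

```python
import jax
import jax.numpy as jnp
from jax import jit


@jit
def f(x, i):
    # i may be a tracer; the slice size (5,) is a static Python tuple
    return jax.lax.dynamic_slice(x, (i,), (5,))


x = jnp.arange(10)
print(f(x, 3))  # [3 4 5 6 7]

# Fixed output shape lets us vmap over start indices (x is shared, starts is batched)
starts = jnp.array([0, 2, 4])
windows = jax.vmap(f, in_axes=(None, 0))(x, starts)
print(windows.shape)  # (3, 5)
```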
    

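    Note that this differs from x[i:i + 5] when the slice would overrun the bounds of the array: instead of returning a shorter array, dynamic_slice clamps the start index so that the full (5,)-shaped slice stays in bounds. In most cases of interest, though, the two are equivalent.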
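    The difference near the array boundary can be checked directly, outside of jit. In this sketch x has length 10 and the window starting at 8 overruns it, so dynamic_slice clamps the start index to 10 - 5 = 5.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
i = 8  # the window x[8:13] runs past the end of x

numpy_style = x[i:i + 5]                        # NumPy semantics: truncated to [8 9]
clamped = jax.lax.dynamic_slice(x, (i,), (5,))  # start clamped to 5: [5 6 7 8 9]

print(numpy_style)
print(clamped)
```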

    There are other cases where neither workaround applies, for example when your program logic depends on creating arrays of dynamic length. In these cases there is no easy workaround, and your best bet is to either (1) rewrite your algorithm in terms of static array shapes, perhaps using padded array representations, or (2) not use JAX.
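    As one illustration of option (1), the sketch below represents x[i:j] as a fixed-length buffer plus a validity mask. padded_slice and max_len are hypothetical names chosen for this example; it assumes 0 <= i and j - i <= max_len, and pads with 0 so that sums over the buffer are unaffected by the padding.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=['max_len'])
def padded_slice(x, i, j, max_len):
    # Hypothetical helper: x[i:j] padded with zeros to the static length max_len,
    # plus a boolean mask marking which entries hold real data.
    # Assumes 0 <= i and j - i <= max_len.
    idx = i + jnp.arange(max_len)                  # candidate positions i, i+1, ...
    mask = (idx < j) & (idx < x.shape[0])          # valid if inside [i, j) and inside x
    safe_idx = jnp.clip(idx, 0, x.shape[0] - 1)    # keep gather indices in bounds
    values = jnp.where(mask, x[safe_idx], 0)       # zero out the padding
    return values, mask


x = jnp.arange(10)
values, mask = padded_slice(x, 2, 5, max_len=6)  # i and j may be traced values
print(values)        # [2 3 4 0 0 0]
print(values.sum())  # 9
```

    Downstream code then works on the fixed-size values array and uses mask wherever the padding would otherwise change the result (for example, in a mean or a max).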