jax complaining about static start/stop/step


Here is a very simple computation in JAX that fails with an error about static slice indices:

import jax
import jax.numpy as jnp

def get_slice(ar, k, i):
  return ar[i:i+k]

vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))

arr = jnp.array([1, 2, 3, 4, 5])

vec_get_slice(arr, 2, jnp.arange(3))
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-32-6c60650ce6b7> in <cell line: 1>()
----> 1 vec_get_slice(arr, 2, jnp.arange(3))

    [... skipping hidden 3 frame]

<ipython-input-29-9528369725c2> in get_slice(ar, k, i)
      1 def get_slice(ar, k, i):
----> 2   return ar[i:i+k]

/usr/local/lib/python3.10/dist-packages/jax/_src/array.py in __getitem__(self, idx)
    346           return out
    347 
--> 348     return lax_numpy._rewriting_take(self, idx)
    349 
    350   def __iter__(self):

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4602 
   4603   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4604   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4605                  unique_indices, mode, fill_value)
   4606 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   4611             unique_indices, mode, fill_value):
   4612   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4613   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   4614   y = arr
   4615 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
   4854                "dynamic_update_slice (JAX does not support dynamically sized "
   4855                "arrays within JIT compiled functions).")
-> 4856         raise IndexError(msg)
   4857 
   4858       start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2], dtype=int32)
  batch_dim = 0, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([2, 3, 4], dtype=int32)
  batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Horrible error output above. I am obviously missing something simple, but what?

Solution

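  • Slice bounds (start/stop/step) used with NumPy-style indexing in JAX must be static, meaning concrete values known at trace time, because they determine the shape of the result. Values that are mapped over in vmap are traced rather than static: because you're mapping over the start index i, both the start (i) and the stop (i+k) of your slice are traced, and you see this error.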

    There is good news, though: the size of your subarray is controlled by k, which is unmapped in your code and therefore static; only the location of the slice (given by i) is dynamic. This is exactly the situation that jax.lax.dynamic_slice was designed for (its start indices may be traced, but its slice sizes must be static), so you can rewrite your code like this:

    import jax
    import jax.numpy as jnp
    
    def get_slice(ar, k, i):
      return jax.lax.dynamic_slice(ar, (i,), (k,))
    
    vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
    
    arr = jnp.array([1, 2, 3, 4, 5])
    
    vec_get_slice(arr, 2, jnp.arange(3))
    # Array([[1, 2],
    #        [2, 3],
    #        [3, 4]], dtype=int32)
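If the vmapped function is also wrapped in jax.jit, k is no longer automatically static: under jit every argument is traced unless it is marked static, and lax.dynamic_slice needs concrete slice sizes. The sketch below is an addition, not part of the original answer; it assumes you want to jit the whole batched function and marks k static through jax.jit's static_argnums. It also shows that dynamic_slice clamps start indices so the length-k window stays inside the array, which matters if i can run past len(arr) - k.

```python
import jax
import jax.numpy as jnp


def get_slice(ar, k, i):
    # Start index i may be traced; the slice size k must be static.
    return jax.lax.dynamic_slice(ar, (i,), (k,))


# Under jit, every argument is traced unless marked static, so k
# (positional argument 1) is declared static; ar and the starts stay traced.
vec_get_slice = jax.jit(
    jax.vmap(get_slice, in_axes=(None, None, 0)),
    static_argnums=1,
)

arr = jnp.array([1, 2, 3, 4, 5])

out = vec_get_slice(arr, 2, jnp.arange(3))
print(out)  # rows: [1, 2], [2, 3], [3, 4]

# dynamic_slice clamps each start so the window stays in bounds:
# starts 3 and 4 both produce [4, 5] here.
clamped = vec_get_slice(arr, 2, jnp.arange(5))
print(clamped)  # rows: [1, 2], [2, 3], [3, 4], [4, 5], [4, 5]
```

Because static arguments are part of jit's cache key, each distinct value of k triggers a separate compilation.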
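Another option, not taken from the original answer, is to build the k positions explicitly and use integer-array indexing, which JAX lowers to a gather and which accepts traced indices. This is a minimal sketch: get_slice_gather is a hypothetical name, and it still relies on k being a concrete Python int so that jnp.arange(k) has a static length.

```python
import jax
import jax.numpy as jnp


def get_slice_gather(ar, k, i):
    # Hypothetical helper: gather positions i, i+1, ..., i+k-1.
    # jnp.arange(k) requires a concrete Python int k (static length);
    # i may be a traced scalar, since integer-array indexing is a gather.
    idx = i + jnp.arange(k)
    return ar[idx]


# Map over the start index only; ar and k are passed through unmapped.
vec_get_slice_gather = jax.vmap(get_slice_gather, in_axes=(None, None, 0))

arr = jnp.array([1, 2, 3, 4, 5])
out = vec_get_slice_gather(arr, 2, jnp.arange(3))
print(out)  # rows: [1, 2], [2, 3], [3, 4]
```

For a contiguous window, lax.dynamic_slice is the more direct fit. The gather version generalizes to non-contiguous offsets, but out-of-bounds positions follow JAX's per-index out-of-bounds indexing rules rather than dynamic_slice's window clamping.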