How do I do a jax get from a masked index?
The code below works without jit.
import jax
import jax.numpy as jnp

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([
    [1, 2],
    [2, 3],
    [1, 2],
    [1, 2],
])
coords_mask = jnp.array([True, True, False, True])

@jax.jit
def masked_gather(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].get()

masked_gather(x, coords, coords_mask)
Fails with NonConcreteBooleanIndexError.
Should return Array([ 7, 13, 7], dtype=int32)
There is no way to execute this function in a JIT-compatible way, because JAX does not support compilation of programs with dynamic shapes. In your case, the size of the returned array depends on the number of True elements in coords_mask, and so the shape is dynamic by definition.
See JAX Sharp Bits: Dynamic Shapes for more information.
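To make the shape dependence concrete, here is a minimal sketch. The helper name `select` and the two example masks are illustrative and not part of the question. Eagerly, the same boolean indexing produces differently shaped outputs for different masks, which is exactly what jit cannot trace:

```python
import jax
import jax.numpy as jnp

coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
mask_a = jnp.array([True, True, False, True])   # three True entries
mask_b = jnp.array([True, False, False, True])  # two True entries

def select(coords, mask):
    # Boolean-mask indexing: the output length equals the number of True entries.
    return coords[mask]

# Without jit the mask values are concrete, so each call simply gets its own shape.
shape_a = select(coords, mask_a).shape  # (3, 2)
shape_b = select(coords, mask_b).shape  # (2, 2)

# Under jit the mask is an abstract tracer, so no single output shape exists.
error_name = None
try:
    jax.jit(select)(coords, mask_a)
except jax.errors.NonConcreteBooleanIndexError as err:
    error_name = type(err).__name__
print(shape_a, shape_b, error_name)
```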
Depending on what you are doing with the resulting value, there are a number of available approaches to work around this: for example, if the shape is truly unknown, you could return an array padded with zeros; it might look something like this:
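@jax.jit
def masked_gather_padded(x, coords, coords_mask, fill_value=0):
    # Replace masked-out coordinates with an out-of-bounds index.
    coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
    # Sort so that entries with a True mask come first.
    order = jnp.argsort(~coords_mask)
    # Out-of-bounds lookups return fill_value instead of raising or clamping.
    result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='fill', fill_value=fill_value)
    return result[order]

masked_gather_padded(x, coords, coords_mask)
# Array([ 7, 13, 7, 0], dtype=int32)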
Alternatively, if the number of True entries in the mask is known a priori, you could modify the function to accept a static size argument and use that to construct an appropriate output. It might look something like this:
from functools import partial

@partial(jax.jit, static_argnames=['size'])
def masked_gather_with_size(x, coords, coords_mask, *, size):
    # Replace masked-out coordinates with an out-of-bounds index.
    coords_masked = jnp.where(coords_mask[:, None], coords, max(x.shape))
    # Sort so that entries with a True mask come first.
    order = jnp.argsort(~coords_mask)
    result = x.at[coords_masked[:, 0], coords_masked[:, 1]].get(mode='drop')
    # size is static, so this slice has a shape known at compile time.
    return result[order[:size]]

masked_gather_with_size(x, coords, coords_mask, size=3)
# Array([ 7, 13, 7], dtype=int32)
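A variant of the static-size idea, sketched here under the assumption that `size` is again known ahead of time, is to let `jnp.nonzero` produce the indices of the True entries directly through its `size` argument. The function name `masked_gather_nonzero` is hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['size'])
def masked_gather_nonzero(x, coords, coords_mask, *, size):
    # Indices of the True entries; `size` fixes the output length at trace time.
    # If fewer than `size` entries are True, the remaining slots are padded
    # with index 0 (the default fill_value of jnp.nonzero).
    (idx,) = jnp.nonzero(coords_mask, size=size)
    selected = coords[idx]
    return x[selected[:, 0], selected[:, 1]]

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

out = masked_gather_nonzero(x, coords, coords_mask, size=3)
print(out)  # expected values: 7, 13, 7
```

Note that if `size` is smaller than the number of True entries the extra ones are silently dropped, so this only matches the unjitted result when `size` equals the true count.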
The best approach will depend on your application.
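For instance, if the gathered values only feed into a reduction, you may not need to compact them at all. The sketch below assumes the downstream operation is a sum (an illustrative choice, not something stated in the question) and uses a hypothetical function name `masked_gather_sum`. It gathers every row and zeroes out the masked ones, so all shapes stay static:

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_gather_sum(x, coords, coords_mask):
    # Gather at every coordinate, including the masked-out ones.
    values = x[coords[:, 0], coords[:, 1]]
    # Masked-out entries contribute 0 to the sum.
    return jnp.where(coords_mask, values, 0).sum()

x = jnp.arange(25).reshape((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])

total = masked_gather_sum(x, coords, coords_mask)
print(total)  # 7 + 13 + 7 = 27
```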