
How to return the last index with jnp.where in a jit function


Say I have two arrays:

import jax.numpy as jnp

z = jnp.array([[5.55751118],
              [5.18212974],
              [4.35981727],
              [3.4559711 ],
              [3.35750248],
              [2.65199945],
              [2.02298999],
              [1.59444971],
              [0.80865185],
              [0.77579791]])

z1 = jnp.array([[ 1.58559484],
               [ 3.79094097],
               [-0.52712522],
               [-1.0178286 ],
               [-3.51076985],
               [ 1.30108161],
               [-1.29824303],
               [-0.19209007],
               [ 0.37451138],
               [-2.33619987]])

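Starting from the first value in z, I would like to find the last row of z whose value is at least z[0] - distance (i.e., no more than distance below the first value), and then return the value of z1 at that row.

Example without @jit: the code below returns the value of z1 at the last index where the condition holds. The value should be about -3.51.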

init = z[0]
distance = 2.6
new = init - distance 

def test():
    idx = z>=new
    val = z1[jnp.where(idx)[0][-1]]
    return val
test()
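As a worked check of the expected value, the sketch below runs the same computation eagerly (without jit) and prints the intermediate results. It flattens the question's (10, 1) arrays to 1-D for readability; the numbers and threshold are otherwise the same.

```python
import jax.numpy as jnp

# Same data as in the question, flattened to 1-D for readability
z = jnp.array([5.55751118, 5.18212974, 4.35981727, 3.4559711, 3.35750248,
               2.65199945, 2.02298999, 1.59444971, 0.80865185, 0.77579791])
z1 = jnp.array([1.58559484, 3.79094097, -0.52712522, -1.0178286, -3.51076985,
                1.30108161, -1.29824303, -0.19209007, 0.37451138, -2.33619987])

new = z[0] - 2.6            # threshold, roughly 2.9575
mask = z >= new             # True for rows 0..4, False from row 5 on
rows = jnp.where(mask)[0]   # eager mode: output length may depend on the data
last = int(rows[-1])        # last row that satisfies the condition
print(mask)
print(last, float(z1[last]))
```

The threshold is 5.55751118 - 2.6 ≈ 2.9575, so rows 0 to 4 (down to z[4] ≈ 3.3575) pass while z[5] ≈ 2.652 does not; the last matching row is 4, and z1[4] ≈ -3.5108.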

When using jit (as needed in a larger-scale model):

from jax import jit

init = z[0]
distance = 2.6
new = init - distance 

@jit
def test():
    idx = z>=new
    val = z1[jnp.where(idx)[0][-1]]
    return val
test()

This error is produced:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function test at /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[10,1] = ge b c
    from line /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:11:10 (test)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
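The failure can be reproduced in isolation. In this minimal sketch (the function name last_true_index is hypothetical), calling jnp.where with only a condition inside a jitted function raises the same ConcretizationTypeError, because the number of True entries, and therefore the length of the result, is not known while jit traces the function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def last_true_index(mask):
    # No size= argument: the output length depends on the values in mask,
    # which are abstract tracers during jit tracing.
    return jnp.where(mask)[0][-1]


try:
    last_true_index(jnp.array([True, True, False]))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print("ConcretizationTypeError raised:", raised)
```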

Solution

  • The problem is that jnp.where called with only a condition returns index arrays whose length depends on the data, and JAX transformations like jit are not compatible with dynamically sized arrays (see JAX Sharp Bits: Dynamic Shapes). You can pass a size argument to jnp.where to make the result statically sized. Since we don't know in advance how many elements will match, we can choose the maximum possible number, which is idx.shape[0] (every row of the (10, 1) mask could match). The result is padded with zeros, the default fill value, so the maximum index is the last matching row, which is what you're looking for:

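    @jit
    def test():
        idx = z >= new
        # size= gives the result a static shape; unused slots are padded with 0
        rows = jnp.where(idx, size=idx.shape[0])[0]
        # padding zeros never exceed a real index, so max() is the last matching row
        val = z1[rows.max()]
        return val
    test()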
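One caveat of padding with zeros is that a mask with no True entries also yields index 0, so z1[0] would be returned silently. A minimal sketch that avoids this, assuming it is acceptable to return NaN when nothing matches, passes the arrays and distance as arguments (so the compiled function can be reused on new data of the same shape) and sets fill_value=-1 so a padded slot can never be mistaken for a real row. It also shows an alternative that skips size entirely: the three-argument form jnp.where(mask, indices, -1) works elementwise, so its output shape is always static. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def value_at_last_match(z, z1, distance):
    """Hypothetical helper: z1 at the last row where z >= z[0] - distance.

    Returns NaN when no row satisfies the condition.
    """
    z, z1 = z.ravel(), z1.ravel()   # accept (n, 1) or (n,) inputs
    mask = z >= z[0] - distance
    # size= fixes the output shape; unused slots get -1 instead of the default 0
    rows = jnp.where(mask, size=mask.shape[0], fill_value=-1)[0]
    last = rows.max()               # stays -1 only if nothing matched
    return jnp.where(last >= 0, z1[last], jnp.nan)


@jax.jit
def last_match_index(mask):
    """Hypothetical alternative: last True index of a 1-D mask, or -1 if none."""
    # Three-argument jnp.where is elementwise, so the output shape is static
    indices = jnp.arange(mask.shape[0])
    return jnp.max(jnp.where(mask, indices, -1))
```

With the question's z and z1, value_at_last_match(z, z1, 2.6) should give about -3.51 (row 4), and last_match_index((z >= z[0] - 2.6).ravel()) should give 4.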