Say I have two arrays:
import jax.numpy as jnp

z = jnp.array([[5.55751118],
[5.18212974],
[4.35981727],
[3.4559711 ],
[3.35750248],
[2.65199945],
[2.02298999],
[1.59444971],
[0.80865185],
[0.77579791]])
z1 = jnp.array([[ 1.58559484],
[ 3.79094097],
[-0.52712522],
[-1.0178286 ],
[-3.51076985],
[ 1.30108161],
[-1.29824303],
[-0.19209007],
[ 0.37451138],
[-2.33619987]])
I would like to start at the first row of array z, find the rows of z whose values are within a threshold of that first value, and take the corresponding value from the second array, z1.
Example without @jit: I would like to return the value of z1 at the last index where this condition holds. The value should be approximately -3.51.
init = z[0]
distance = 2.6
new = init - distance
def test():
    idx = z >= new                     # boolean mask: rows of z within the threshold
    val = z1[jnp.where(idx)[0][-1]]    # value of z1 at the last matching row
    return val
test()
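A quick eager-mode check of where -3.51 comes from, offered as a sketch: the arrays are flattened to 1-D here for readability (an assumption that does not change the result), and `threshold`, `mask` and `last` are illustrative names. The threshold is 5.55751118 - 2.6 = 2.95751118; rows 0 to 4 of z are at or above it (row 4 is 3.35750248), row 5 (2.65199945) is below it, so the last matching row is 4 and z1[4] is -3.51076985.

```python
import jax.numpy as jnp

# Same data as the question, flattened from (10, 1) columns to 1-D arrays.
z = jnp.array([5.55751118, 5.18212974, 4.35981727, 3.4559711, 3.35750248,
               2.65199945, 2.02298999, 1.59444971, 0.80865185, 0.77579791])
z1 = jnp.array([1.58559484, 3.79094097, -0.52712522, -1.0178286, -3.51076985,
                1.30108161, -1.29824303, -0.19209007, 0.37451138, -2.33619987])

distance = 2.6
threshold = z[0] - distance             # 5.55751118 - 2.6 = 2.95751118
mask = z >= threshold                   # True for rows 0..4, False for rows 5..9
last = int(jnp.where(mask)[0][-1])      # eager mode: a data-dependent size is fine here
print(last, float(z1[last]))
```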
When using JIT (as needed in a larger-scale model),
from jax import jit

init = z[0]
distance = 2.6
new = init - distance
@jit
def test():
    idx = z >= new
    val = z1[jnp.where(idx)[0][-1]]    # fails under jit: the output size depends on the data
    return val
test()
this error is produced:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function test at /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:bool[10,1] = ge b c
from line /var/folders/ss/pfgdfm2x7_s4cyw2v0b_t7q80000gn/T/ipykernel_85273/75296347.py:11:10 (test)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
The problem is that jnp.where returns a dynamically sized array, and JAX transformations like jit are not compatible with dynamically sized arrays (see JAX Sharp Bits: Dynamic Shapes). You can pass a size argument to jnp.where to make the result statically sized. Since we don't know in advance how many elements will be returned, we can choose the maximum possible number, which is idx.shape[0]. The matching indices come back in ascending order and the remaining slots are padded with zeros, so the maximum index is the last index where the condition holds, which is what you're looking for:
from jax import jit

@jit
def test():
    idx = z >= new
    # size= makes the output shape static; unused slots are padded with 0
    val = z1[jnp.where(idx, size=idx.shape[0])[0].max()]
    return val
test()
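One caveat with zero padding: if no row satisfies the condition, every padded index is 0, so the max is 0 and the function silently returns z1[0]. Below is a sketch of an alternative that sidesteps the size argument entirely by using the three-argument form of jnp.where (matching rows keep their index, the rest get -1) and then guarding the no-match case. The helper name `last_match_value` and the choice to return NaN when nothing matches are illustrative assumptions, not part of the original answer.

```python
import jax
import jax.numpy as jnp


@jax.jit
def last_match_value(z, z1, distance):
    """Hypothetical helper: value of z1 at the last row where z >= z[0] - distance.

    Returns (value, index); index is -1 and value is NaN when no row matches.
    """
    # Flatten so (10, 1) column arrays like the question's also work.
    z, z1 = z.ravel(), z1.ravel()
    threshold = z[0] - distance
    mask = z >= threshold
    # Three-argument jnp.where keeps a static shape: matching rows get their
    # index, non-matching rows get -1.
    candidates = jnp.where(mask, jnp.arange(z.shape[0]), -1)
    last = candidates.max()
    # Without this guard, last == -1 would wrap around to the final element of z1.
    value = jnp.where(last >= 0, z1[last], jnp.nan)
    return value, last


z = jnp.array([5.55751118, 5.18212974, 4.35981727, 3.4559711, 3.35750248,
               2.65199945, 2.02298999, 1.59444971, 0.80865185, 0.77579791])
z1 = jnp.array([1.58559484, 3.79094097, -0.52712522, -1.0178286, -3.51076985,
                1.30108161, -1.29824303, -0.19209007, 0.37451138, -2.33619987])

val, last = last_match_value(z, z1, 2.6)              # last is 4, val is z1[4]
# A negative distance puts the threshold above z[0], so no row matches.
none_val, none_last = last_match_value(z, z1, -10.0)  # last is -1, val is NaN
```

Passing the arrays as arguments instead of closing over globals means one compiled version can be reused for any inputs of the same shape and dtype. If you would rather keep the size-based version, jnp.where also accepts a fill_value argument when size is given, so `jnp.where(idx, size=idx.shape[0], fill_value=-1)[0].max()` should yield -1 instead of 0 when nothing matches.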