```python
def factory(points):
    points.sort()
    @jax.jit
    def fwd(x):
        for i in range(1, len(points)):
            if x < points[i][0]:
                return (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * x + (points[i][1] - (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * points[i][0])
        i = len(points) - 1
        return (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * x + (points[i][1] - (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * points[i][0])
    return fwd
```
I want to write a function that creates a jitted function, given one argument: `points`, a list containing pairs of numbers. I'm aware that Python `if`/`else` statements on traced values can't be jitted, and that `jax.lax.cond()` allows conditions, but I want something like a `break`, as you can see in the code above. Is there any way to work with conditions like this?
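The challenge in making this JAX-compatible is that your function relies on Python control flow driven by the values of its input. Under `jax.jit`, `x` is an abstract tracer, so `x < points[i][0]` has no concrete `True`/`False` value for a Python `if` (or an early `return`) to act on. To make this work with JAX, you should convert it to a vectorized operation. Here's how you might express your operation in terms of `np.where` rather than `for` loops: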
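```python
import numpy as np

def factory_v2(points):
    points = np.asarray(points)
    # Sort rows by x, then y, as list.sort() does for [x, y] pairs
    # (ndarray.sort() would instead sort the two values inside each row).
    points = points[np.lexsort((points[:, 1], points[:, 0]))]
    def fwd(x):
        # Positions (within points[1:]) of the breakpoints lying to the right of x.
        matches = np.where(x < points[1:, 0])[0]
        # The first match replaces `break`; with no match, use the last segment.
        i = matches[0] + 1 if len(matches) else len(points) - 1
        slope = (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0])
        return slope * x + (points[i, 1] - slope * points[i, 0])
    return fwd

x = 2
points = np.array([[4, 0], [2, 1], [6, 5], [4, 6], [5, 7]])
# Your original factory, run with @jax.jit removed (its Python `if` fails under jit)
# and given a plain list, since it sorts its argument in place with list.sort().
print(factory(points.tolist())(x))
# 1.0
print(factory_v2(points)(x))
# 1.0
```

This is closer to a JAX-compatible operation, but unfortunately it relies on dynamically shaped arrays: the length of `matches` depends on the value of `x`, which is unknown while JAX traces the function, and `if len(matches)` is once again Python control flow on that value. You can work around this dynamic-size issue with the JAX-only `size` and `fill_value` arguments to `jnp.where`: `size` fixes the length of the result, and `fill_value` is used for any position that has no match. Here's a JAX-compatible version:

```python
import jax
import jax.numpy as jnp

def factory_jax(points):
    points = jnp.asarray(points)
    # Sort rows by x, then y (jnp.sort would also sort within each row).
    points = points[jnp.lexsort((points[:, 1], points[:, 0]))]
    @jax.jit
    def fwd(x):
        # size=1 gives where() a static output shape; when no breakpoint lies to the
        # right of x, fill_value selects the last segment instead.
        i = 1 + jnp.where(x < points[1:, 0], size=1, fill_value=len(points) - 2)[0][0]
        slope = (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0])
        return slope * x + (points[i, 1] - slope * points[i, 0])
    return fwd

print(factory_jax(points)(x))
# 1.0
```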
If I've understood the intended input shapes for your code, I believe this should compute the same results as your original function.
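If you also want to evaluate many `x` values in one call, another option (a sketch under the same input assumptions, not part of the approach above) is `jnp.searchsorted`: the number of breakpoints in `xp[1:]` that are `<= x` is exactly the index at which your loop would `break`. The name `factory_searchsorted` is hypothetical. Note that `jnp.interp` would be shorter, but it clamps to the end values outside the data range instead of extrapolating with the first and last segments as your loop does.

```python
import jax
import jax.numpy as jnp

def factory_searchsorted(points):
    """Hypothetical alternative to factory_jax built on jnp.searchsorted."""
    points = jnp.asarray(points, dtype=jnp.float32)
    # Sort rows by x, then y, as list.sort() does for [x, y] pairs.
    points = points[jnp.lexsort((points[:, 1], points[:, 0]))]
    xp, yp = points[:, 0], points[:, 1]

    @jax.jit
    def fwd(x):
        # Count of breakpoints in xp[1:] that are <= x, i.e. the position of the
        # first one that is > x (where the original loop would stop).
        j = jnp.searchsorted(xp[1:], x, side="right")
        # Segment i runs from point i-1 to point i; past the last breakpoint,
        # reuse the final segment so the line is extrapolated.
        i = jnp.minimum(j + 1, len(xp) - 1)
        slope = (yp[i] - yp[i - 1]) / (xp[i] - xp[i - 1])
        return slope * x + (yp[i] - slope * xp[i])

    return fwd

f = factory_searchsorted([[4, 0], [2, 1], [6, 5], [4, 6], [5, 7]])
print(f(2.0))  # 1.0
print(f(jnp.array([1.0, 3.0, 4.5, 7.0])))  # values: 1.5, 0.5, 6.5, 3.0
```

Because `searchsorted`, the indexing, and the arithmetic all broadcast over `x`, this `fwd` accepts either a scalar or an array without needing `jax.vmap`.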