
How to make this function jax.jit-able?


import jax

def factory(points):
  points.sort()  # sort the (x, y) pairs by x, then by y
  @jax.jit
  def fwd(x):
    # Find the first breakpoint to the right of x and evaluate the line
    # through points[i - 1] and points[i] at x.
    for i in range(1, len(points)):
      if x < points[i][0]:
        return (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * x + (points[i][1] - (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * points[i][0])
    # x lies beyond every breakpoint: extend the last segment.
    i = len(points) - 1
    return (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * x + (points[i][1] - (points[i][1] - points[i - 1][1]) / (points[i][0] - points[i - 1][0]) * points[i][0])
  return fwd

I want to write a function that creates a jitted function, given one argument: points, a list containing pairs of numbers. I am aware that if/else statements on traced values can't be jitted and that jax.lax.cond() allows conditions, but I want something like a break, as you can see in the code above. Is there any way to work with conditions?
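For context, here is a minimal sketch of the failure described above (the function names are hypothetical). Under jax.jit, x is an abstract tracer, so a Python `if` on it raises a ConcretizationTypeError (TracerBoolConversionError in recent JAX versions), whereas jax.lax.cond traces both branches and selects one based on the traced predicate. Unlike a loop with `break`, though, lax.cond only chooses between two functions.

```python
import jax

@jax.jit
def with_python_if(x):
  # Python needs a concrete bool here, but under jit x is an abstract tracer.
  if x < 4:
    return x * 2
  return x

@jax.jit
def with_lax_cond(x):
  # lax.cond traces both branches and selects one based on the traced predicate.
  return jax.lax.cond(x < 4, lambda v: v * 2, lambda v: v, x)

try:
  with_python_if(2)
except jax.errors.ConcretizationTypeError as err:
  print("Python if failed under jit:", type(err).__name__)

print(with_lax_cond(2))  # 4
print(with_lax_cond(5))  # 5
```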


Solution

  • The challenge in converting this to a JAX-compatible form is that your function relies on Python control flow that depends on the value of x, which under jax.jit is only an abstract tracer; to make it compatible with JAX, you should convert it to a vectorized operation. Here's how you might express your operation in terms of np.where rather than for loops:

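    import jax
    import numpy as np
    
    def factory_v2(points):
      points = np.asarray(points)
      # Sort the pairs by x, then by y, as list.sort() does for a list of pairs.
      # (ndarray.sort() would instead sort each row on its own, scrambling the pairs.)
      points = points[np.lexsort((points[:, 1], points[:, 0]))]
      def fwd(x):
        # Positions (within points[1:]) of all breakpoints lying to the right of x.
        matches = np.where(x < points[1:, 0])[0]
        # Use the first such breakpoint, or the last segment if there is none.
        i = matches[0] + 1 if len(matches) else len(points) - 1
        return (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0]) * x + (points[i, 1] - (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0]) * points[i, 0])
      return fwd
    
    x = 2
    points = np.array([[4, 0], [2, 1], [6, 5], [4, 6], [5, 7]])
    
    # The original factory can't be traced (its `if` needs a concrete x), so compare
    # against it with jit disabled, passing a list of pairs as in the question.
    with jax.disable_jit():
      print(factory(points.tolist())(x))
    # 1.0
    
    print(factory_v2(points)(x))
    # 1.0
    

    This is closer to a JAX-compatible operation, but unfortunately it relies on creating dynamically-shaped arrays: the length of matches depends on the value of x, while JAX needs every array shape to be known when the function is traced. jnp.where accepts two JAX-only arguments for exactly this case: size fixes the length of the output, and fill_value supplies the index used to pad it, which here takes over the job of the len(matches) check. Here's a JAX-compatible version that uses them:

    import jax
    import jax.numpy as jnp
    
    def factory_jax(points):
      points = jnp.asarray(points)
      # Sort the pairs by x, then by y (jnp.sort would sort within each row instead).
      points = points[jnp.lexsort((points[:, 1], points[:, 0]))]
      @jax.jit
      def fwd(x):
        # size=1 keeps the output shape static; if no breakpoint lies to the right
        # of x, fill_value makes i point at the last segment.
        i = 1 + jnp.where(x < points[1:, 0], size=1, fill_value=len(points) - 2)[0][0]
        return (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0]) * x + (points[i, 1] - (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0]) * points[i, 0])
      return fwd
    
    print(factory_jax(points)(x))
    # 1.0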
    

    If I've understood the intended input shapes for your code, I believe this should compute the same results as your original function.
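Since the question asks for something like a `break`, here is a minimal sketch of the closest JAX construct: jax.lax.while_loop, whose condition is re-checked on every iteration, so the loop stops at the first breakpoint to the right of x. The name factory_loop is hypothetical, and the sketch assumes the points can be converted to a float32 (N, 2) array. For a handful of points the jnp.where version is usually simpler; a while_loop runs sequentially and does not support reverse-mode differentiation (jax.grad).

```python
import jax
import jax.numpy as jnp
import numpy as np

def factory_loop(points):
  # Hypothetical name. Sort the (x, y) pairs by x, then y, as list.sort() would.
  points = np.asarray(points, dtype=np.float32)
  points = jnp.asarray(points[np.lexsort((points[:, 1], points[:, 0]))])
  n = points.shape[0]

  @jax.jit
  def fwd(x):
    # Advance i while x is not left of points[i]; stopping at the first i with
    # x < points[i, 0] plays the role of `break`. The loop also stops at the
    # last segment (i = n - 1) if no breakpoint qualifies.
    def keep_going(i):
      return (i < n - 1) & (x >= points[i, 0])

    i = jax.lax.while_loop(keep_going, lambda i: i + 1, 1)
    # Evaluate the line through points[i - 1] and points[i] at x.
    slope = (points[i, 1] - points[i - 1, 1]) / (points[i, 0] - points[i - 1, 0])
    return slope * x + (points[i, 1] - slope * points[i, 0])

  return fwd

f = factory_loop([[4, 0], [2, 1], [6, 5], [4, 6], [5, 7]])
print(f(2.0))  # 1.0
```

Because the loop bound n is a Python int fixed when the factory runs, the traced function has static shapes, and only the loop counter i depends on x.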