python, jit, jax

Iterators in jit JAX functions


I'm new to JAX, and while reading the docs I found that jitted functions should not contain iterators (see the section on pure functions).

They give this example:

import jax.numpy as jnp
import jax.lax as lax
from jax import jit

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

Fiddling with it a little to see whether I could get an error directly instead of undefined behaviour, I wrote:

@jit
def f(x, arr):
    for i in range(10):
        x += arr[i]
    return x

@jit
def f1(x, arr):
    it = iter(arr)
    for i in range(10):
        x += next(it)
    return x

print(f(0,array)) # 45 as expected
print(f1(0,array)) # still 45 

Is it just by chance that the jitted function f1() now shows the correct behaviour?


Solution

Your code works because of the way JAX's tracing model works. When JAX's tracing encounters Python control flow such as a for loop, the loop is fully evaluated at trace time: it is unrolled, and the operations from each iteration are recorded one after another (there's some exploration of this in JAX Sharp Bits: Control Flow).
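
One way to see the unrolling is to compare the traced programs (jaxprs) of a Python loop and of lax.fori_loop with jax.make_jaxpr. This is a minimal sketch; the names f_python, f_lax and primitive_names are illustrative, and the exact primitives listed depend on the JAX version. The point is only that the Python loop leaves no loop primitive behind, while fori_loop does.

```python
import jax
import jax.numpy as jnp
from jax import lax

array = jnp.arange(10)

def f_python(x, arr):
    # Python loop: it runs while JAX traces, so it is unrolled into 10 steps
    for i in range(10):
        x += arr[i]
    return x

def f_lax(x, arr):
    # lax.fori_loop: the body is traced once and staged out as a loop primitive
    return lax.fori_loop(0, 10, lambda i, acc: acc + arr[i], x)

def primitive_names(closed_jaxpr):
    """Names of the top-level primitives in a traced program."""
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

jaxpr_python = jax.make_jaxpr(f_python)(0, array)
jaxpr_lax = jax.make_jaxpr(f_lax)(0, array)

print(primitive_names(jaxpr_python))  # no scan/while: the loop was unrolled
print(primitive_names(jaxpr_lax))     # includes a loop primitive (scan or while)
```

Both functions compute 45 here; they differ only in how the loop is represented in the compiled program.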

Because of this, using an iterator in this context is fine: every iteration is evaluated at trace time, so next(it) is called once per unrolled iteration and returns the next element of arr each time.
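
To make the per-iteration next() calls visible, the sketch below wraps the iterator in a small counting class. CountingIterator and the seen list are hypothetical helpers added for illustration; they are not part of JAX. Because Python code inside a jitted function only runs while tracing, the counter should show ten next() calls during the first call and no new trace on a second call with the same argument types.

```python
import jax
import jax.numpy as jnp

class CountingIterator:
    """Hypothetical helper: wraps an iterator and counts next() calls."""

    def __init__(self, iterable):
        self._it = iter(iterable)  # iterating a traced array yields traced elements
        self.calls = 0

    def __iter__(self):
        return self

    def __next__(self):
        self.calls += 1
        return next(self._it)

seen = []  # records every iterator created while f1 is being traced

@jax.jit
def f1(x, arr):
    it = CountingIterator(arr)
    seen.append(it)        # Python side effect: happens at trace time only
    for _ in range(10):
        x += next(it)      # one next() call per unrolled iteration
    return x

array = jnp.arange(10)
print(f1(0, array))              # 45
print(len(seen), seen[0].calls)  # 1 10: one trace, ten next() calls
print(f1(0, array))              # 45, reusing the cached compiled function
print(len(seen))                 # 1: the Python body did not run again
```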

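In contrast, lax.fori_loop traces its body function only once, so next(iterator) is executed a single time. Its output (0, the first element of the range) is treated as a trace-time constant that does not change during the runtime iterations: every iteration adds 0, which is why the result is 0.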
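
The docs example can be checked directly by looking at how far the iterator advanced after the loop. This is a sketch assuming the body is traced exactly once, which holds here because the carry type does not change between iterations. The variable name consumed is illustrative.

```python
import jax.numpy as jnp
from jax import lax

iterator = iter(range(10))
result = lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0)

consumed = next(iterator)  # the element after whatever tracing used up
print(result)    # 0: next() ran once at trace time and returned 0
print(consumed)  # 1: only a single element was consumed during tracing
```

If the loop really needs different values on each iteration, the body has to compute them from traced values (for example array[i], as in the first line of the docs example) rather than from Python state.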