
Rewriting for loop with jax.lax.scan


I'm having trouble understanding the JAX documentation. Can somebody give me a hint on how to rewrite simple code like this with jax.lax.scan?

import numpy

numbers = numpy.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
evenNumbers = 0
for row in numbers:
    for n in row:
        if n % 2 == 0:
            evenNumbers += 1

Solution

  • Assuming the goal is to demonstrate the concepts rather than to optimize the example shown: the function passed to jax.lax.scan must have the signature f(carry, x) -> (new_carry, y), and any condition that depends on traced values has to be replaced with jax.lax.cond. The code below is the closest to the original I could come up with, but please be aware that I'm anything but a jaxpert.

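    import jax
    import jax.numpy as jnp
    
    def f(carry, row):
        # carry: running count of even numbers so far; row: one row of `numbers`
        even = 0
        for n in row:  # row has a static length, so this Python loop is unrolled at trace time
            # n is a traced value, so a plain `if` would fail; lax.cond selects a branch instead
            even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)
    
        # return (new carry, per-step output); scan stacks the per-step outputs
        return carry + even, even
    
    numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
    jax.lax.scan(f, 0, numbers)  # returns (final carry, stacked per-row counts)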
    

    Output

    (DeviceArray(2, dtype=int32, weak_type=True),
     DeviceArray([1, 0, 1], dtype=int32, weak_type=True))
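If the documentation is the sticking point, it may help to see that jax.lax.scan behaves roughly like the plain Python loop below, which is close to how the JAX docs themselves describe it. This is a minimal sketch, assuming a single array for xs (the real scan also accepts pytrees and options such as `length` and `reverse`). `scan_reference` and `count_even` are hypothetical names for illustration only. Because this version runs eagerly on NumPy values with no tracing, an ordinary `if` is fine here.

```python
import numpy as np

def scan_reference(f, init, xs):
    """Hypothetical plain-Python sketch of what jax.lax.scan computes (no tracing, no jit)."""
    carry = init
    ys = []
    for x in xs:                 # iterate over the leading axis of xs
        carry, y = f(carry, x)   # step function maps (carry, x) -> (new_carry, y)
        ys.append(y)
    return carry, np.stack(ys)   # final carry plus the per-step outputs stacked together

def count_even(carry, row):
    # Hypothetical step function mirroring the answer's f, with ordinary Python control flow
    even = sum(1 for n in row if n % 2 == 0)
    return carry + even, even

numbers = np.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
total, per_row = scan_reference(count_even, 0, numbers)
print(total, per_row)
```

The pair it returns has the same shape as the scan output above: the final carry (the total count) and one count per row.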
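If the inner Python loop and jax.lax.cond are not needed for illustration, the step function can test the whole row at once. Inside scan a plain `if n % 2 == 0` cannot work, because the condition is a traced value rather than a Python bool. An elementwise comparison avoids branching entirely: `row % 2 == 0` yields a boolean vector, and summing it with `dtype=jnp.int32` counts the evens. A sketch, assuming the same `numbers` array; the initial carry is written as `jnp.int32(0)` so its dtype matches the carry the step function returns.

```python
import jax
import jax.numpy as jnp

def f(carry, row):
    # Compare the whole row at once instead of looping with lax.cond
    even = jnp.sum(row % 2 == 0, dtype=jnp.int32)
    # New carry is the running total; the per-row count is stacked by scan
    return carry + even, even

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
total, per_row = jax.lax.scan(f, jnp.int32(0), numbers)

# Sanity check against a fully vectorized count that uses no scan at all
assert int(total) == int(jnp.sum(numbers % 2 == 0))
print(total, per_row)
```

For this particular example the vectorized `jnp.sum(numbers % 2 == 0)` in the check is what you would normally write; scan earns its keep when each step genuinely depends on the carry from the previous one.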