
How to implement nested for loops with branches efficiently in JAX


I want to reimplement in JAX a function that loops over a 2D array and, based on several conditions, modifies an output array at an index that is not necessarily the current iteration index. Currently I implement this with repeated use of jnp.where, one pass per condition. On CPU the result is ~4x slower than my Numba implementation (on GPU it is ~10x faster), and I suspect the CPU slowdown comes from iterating over the whole array again for every condition.

The Numba implementation, together with the shared setup code, is as follows:

import jax
jax.config.update("jax_enable_x64", True)  # 64-bit floats/ints, matching NumPy
import jax.numpy as jnp
import numpy as np
import numba as nb

rng = np.random.default_rng()


@nb.njit
def raytrace_np(ir, dx, dy):
    assert ir.ndim == 2
    n, m = ir.shape
    assert ir.shape == dx.shape == dy.shape
    output = np.zeros_like(ir)

    # Bilinearly distribute each ir[i, j] over the four integer grid
    # points surrounding the displaced position (i + dy, j + dx).
    for i in range(n):
        for j in range(m):
            dx_ij = dx[i, j]
            dy_ij = dy[i, j]
            
            dxf_ij = np.floor(dx_ij)
            dyf_ij = np.floor(dy_ij)

            ir_ij = ir[i, j]
            index0 = i + int(dyf_ij)
            index1 = j + int(dxf_ij)

            if 0 <= index0 <= n - 1 and 0 <= index1 <= m - 1:
                output[index0, index1] += (
                    ir_ij * (1 - (dx_ij - dxf_ij)) * (1 - (dy_ij - dyf_ij))
                )
            if 0 <= index0 <= n - 1 and 0 <= index1 + 1 <= m - 1:
                output[index0, index1 + 1] += (
                    ir_ij * (dx_ij - dxf_ij) * (1 - (dy_ij - dyf_ij))
                )
            if 0 <= index0 + 1 <= n - 1 and 0 <= index1 <= m - 1:
                output[index0 + 1, index1] += (
                    ir_ij * (1 - (dx_ij - dxf_ij)) * (dy_ij - dyf_ij)
                )
            if 0 <= index0 + 1 <= n - 1 and 0 <= index1 + 1 <= m - 1:
                output[index0 + 1, index1 + 1] += (
                    ir_ij * (dx_ij - dxf_ij) * (dy_ij - dyf_ij)
                )
    return output

and my current JAX reimplementation is:

@jax.jit
def raytrace_jax(ir, dx, dy):
    assert ir.ndim == 2
    n, m = ir.shape
    assert ir.shape == dx.shape == dy.shape

    output = jnp.zeros_like(ir)

    dxfloor = jnp.floor(dx)
    dyfloor = jnp.floor(dy)
    
    dxfloor_int = dxfloor.astype(jnp.int64)
    dyfloor_int = dyfloor.astype(jnp.int64)
    
    meshyfloor = dyfloor_int + jnp.arange(n)[:, None]
    meshxfloor = dxfloor_int + jnp.arange(m)[None]

    validx = (meshxfloor >= 0) & (meshxfloor <= m - 1)
    validy = (meshyfloor >= 0) & (meshyfloor <= n - 1)
    validx2 = (meshxfloor + 1 >= 0) & (meshxfloor + 1 <= m - 1)
    validy2 = (meshyfloor + 1 >= 0) & (meshyfloor + 1 <= n - 1)

    validxy = validx & validy
    validx2y = validx2 & validy
    validxy2 = validx & validy2
    validx2y2 = validx2 & validy2
    
    # Fractional parts serve as the bilinear weights.
    dx_dxfloor = dx - dxfloor
    dy_dyfloor = dy - dyfloor

    # Four masked scatter-adds, one per corner: invalid targets are
    # redirected to index 0 with zero weight so they contribute nothing.
    output = output.at[
        jnp.where(validxy, meshyfloor, 0), jnp.where(validxy, meshxfloor, 0)
    ].add(
        jnp.where(validxy, ir * (1 - dx_dxfloor) * (1 - dy_dyfloor), 0)
    )
    output = output.at[
        jnp.where(validx2y, meshyfloor, 0),
        jnp.where(validx2y, meshxfloor + 1, 0),
    ].add(jnp.where(validx2y, ir * dx_dxfloor * (1 - dy_dyfloor), 0))
    output = output.at[
        jnp.where(validxy2, meshyfloor + 1, 0),
        jnp.where(validxy2, meshxfloor, 0),
    ].add(jnp.where(validxy2, ir * (1 - dx_dxfloor) * dy_dyfloor, 0))
    output = output.at[
        jnp.where(validx2y2, meshyfloor + 1, 0),
        jnp.where(validx2y2, meshxfloor + 1, 0),
    ].add(jnp.where(validx2y2, ir * dx_dxfloor * dy_dyfloor, 0))
    return output

Test and timings:

shape = 2000, 2000
ir = rng.random(shape)
dx = (rng.random(shape) - 0.5) * 5
dy = (rng.random(shape) - 0.5) * 5

_raytrace_np = raytrace_np(ir, dx, dy)
_raytrace_jax = raytrace_jax(ir, dx, dy).block_until_ready()

assert np.allclose(_raytrace_np, _raytrace_jax)

%timeit raytrace_np(ir, dx, dy)
%timeit raytrace_jax(ir, dx, dy).block_until_ready()

Output:

14.3 ms ± 84.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.9 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So: is there a way to implement this algorithm in JAX with performance closer to the Numba implementation?


Solution

  • The way you've implemented it in JAX is pretty close to what I'd recommend. Yes, it's ~4x slower than a custom Numba implementation on CPU, but for an operation like this, that is to be expected.

    The operation you defined applies specific logic to each individual entry of the array. That is precisely the computational regime Numba is designed for, and precisely the kind of computation CPUs are designed for, so it's not surprising that the Numba version is very fast on CPU.

    I suspect the reason you used Numba rather than NumPy here is that NumPy is not designed for this sort of algorithm: it is an array-oriented language, not an array-element-oriented language. JAX/XLA is more similar to NumPy than to Numba: it too is array-oriented, encoding operations across whole arrays at once rather than choosing a different computation per element (see the toy illustration below).

    The benefit of this array-oriented computing model becomes apparent when you move away from CPU and run the code on an accelerator like a GPU or TPU: that hardware is specifically designed for vectorized array operations, which is why you found the same array-oriented code to be 10x faster on GPU.
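
    As a toy illustration of that contrast, here is the same per-entry branch written both ways (hypothetical example code, not part of the question):

    import numpy as np
    import jax.numpy as jnp

    def clip_negative_elementwise(x):
        # Element-oriented (Numba-style): a Python-level branch per entry.
        out = np.empty_like(x)
        for i in range(x.size):
            out.flat[i] = x.flat[i] if x.flat[i] > 0 else 0.0
        return out

    def clip_negative_arraywise(x):
        # Array-oriented (NumPy/JAX-style): one operation over the whole
        # array, with the branch expressed as a mask.
        return jnp.where(x > 0, x, 0.0)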
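
    That said, if the four separate scatter passes bother you, one variation worth trying is to stack the four corner updates and issue them as a single masked scatter-add. The sketch below follows the same masking logic as your code; the name raytrace_jax_fused is mine and I haven't benchmarked it, so whether XLA handles one stacked scatter better than four separate ones may depend on backend and version:

    @jax.jit
    def raytrace_jax_fused(ir, dx, dy):
        n, m = ir.shape
        dxf = jnp.floor(dx)
        dyf = jnp.floor(dy)
        fx = dx - dxf  # fractional parts = bilinear weights
        fy = dy - dyf
        iy = dyf.astype(jnp.int64) + jnp.arange(n)[:, None]
        ix = dxf.astype(jnp.int64) + jnp.arange(m)[None]

        # Stack the four corner targets and their weights: shape (4, n, m).
        ys = jnp.stack([iy, iy, iy + 1, iy + 1])
        xs = jnp.stack([ix, ix + 1, ix, ix + 1])
        ws = jnp.stack([
            ir * (1 - fx) * (1 - fy),
            ir * fx * (1 - fy),
            ir * (1 - fx) * fy,
            ir * fx * fy,
        ])

        # Mask out-of-bounds contributions as in your version: redirect
        # their indices to (0, 0) and zero their weights.
        valid = (ys >= 0) & (ys < n) & (xs >= 0) & (xs < m)
        ws = jnp.where(valid, ws, 0.0)
        ys = jnp.where(valid, ys, 0)
        xs = jnp.where(valid, xs, 0)

        # A single scatter-add; duplicate target indices accumulate.
        return jnp.zeros_like(ir).at[ys, xs].add(ws)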