I want to reimplement a function in JAX that loops over a 2D array and, depending on some conditions, modifies the output array at an index that is not necessarily the current loop index. Currently I implement this with repeated use of jnp.where, once per condition, but the function is ~4x slower than the Numba implementation on CPU (on GPU it is ~10x faster). I suspect the CPU slowdown comes from iterating over the whole array again for every condition.
The numba implementation is as follows:
import jax
import jax.numpy as jnp
import numpy as np
import numba as nb

# Enable 64-bit floats and ints so the JAX results match NumPy/Numba.
jax.config.update("jax_enable_x64", True)

rng = np.random.default_rng()
@nb.njit
def raytrace_np(ir, dx, dy):
    assert ir.ndim == 2
    n, m = ir.shape
    assert ir.shape == dx.shape == dy.shape
    output = np.zeros_like(ir)
    for i in range(ir.shape[0]):
        for j in range(ir.shape[1]):
            dx_ij = dx[i, j]
            dy_ij = dy[i, j]
            dxf_ij = np.floor(dx_ij)
            dyf_ij = np.floor(dy_ij)
            ir_ij = ir[i, j]
            index0 = i + int(dyf_ij)
            index1 = j + int(dxf_ij)
            if 0 <= index0 <= n - 1 and 0 <= index1 <= m - 1:
                output[index0, index1] += (
                    ir_ij * (1 - (dx_ij - dxf_ij)) * (1 - (dy_ij - dyf_ij))
                )
            if 0 <= index0 <= n - 1 and 0 <= index1 + 1 <= m - 1:
                output[index0, index1 + 1] += (
                    ir_ij * (dx_ij - dxf_ij) * (1 - (dy_ij - dyf_ij))
                )
            if 0 <= index0 + 1 <= n - 1 and 0 <= index1 <= m - 1:
                output[index0 + 1, index1] += (
                    ir_ij * (1 - (dx_ij - dxf_ij)) * (dy_ij - dyf_ij)
                )
            if 0 <= index0 + 1 <= n - 1 and 0 <= index1 + 1 <= m - 1:
                output[index0 + 1, index1 + 1] += (
                    ir_ij * (dx_ij - dxf_ij) * (dy_ij - dyf_ij)
                )
    return output
And my current JAX reimplementation is:
@jax.jit
def raytrace_jax(ir, dx, dy):
    assert ir.ndim == 2
    n, m = ir.shape
    assert ir.shape == dx.shape == dy.shape

    output = jnp.zeros_like(ir)

    dxfloor = jnp.floor(dx)
    dyfloor = jnp.floor(dy)

    dxfloor_int = dxfloor.astype(jnp.int64)
    dyfloor_int = dyfloor.astype(jnp.int64)

    meshyfloor = dyfloor_int + jnp.arange(n)[:, None]
    meshxfloor = dxfloor_int + jnp.arange(m)[None]

    validx = (meshxfloor >= 0) & (meshxfloor <= m - 1)
    validy = (meshyfloor >= 0) & (meshyfloor <= n - 1)
    validx2 = (meshxfloor + 1 >= 0) & (meshxfloor + 1 <= m - 1)
    validy2 = (meshyfloor + 1 >= 0) & (meshyfloor + 1 <= n - 1)

    validxy = validx & validy
    validx2y = validx2 & validy
    validxy2 = validx & validy2
    validx2y2 = validx2 & validy2

    dx_dxfloor = dx - dxfloor
    dy_dyfloor = dy - dyfloor

    output = output.at[
        jnp.where(validxy, meshyfloor, 0), jnp.where(validxy, meshxfloor, 0)
    ].add(
        jnp.where(validxy, ir * (1 - dx_dxfloor) * (1 - dy_dyfloor), 0)
    )
    output = output.at[
        jnp.where(validx2y, meshyfloor, 0),
        jnp.where(validx2y, meshxfloor + 1, 0),
    ].add(jnp.where(validx2y, ir * dx_dxfloor * (1 - dy_dyfloor), 0))
    output = output.at[
        jnp.where(validxy2, meshyfloor + 1, 0),
        jnp.where(validxy2, meshxfloor, 0),
    ].add(jnp.where(validxy2, ir * (1 - dx_dxfloor) * dy_dyfloor, 0))
    output = output.at[
        jnp.where(validx2y2, meshyfloor + 1, 0),
        jnp.where(validx2y2, meshxfloor + 1, 0),
    ].add(jnp.where(validx2y2, ir * dx_dxfloor * dy_dyfloor, 0))
    return output
Test and timings:
shape = 2000, 2000
ir = rng.random(shape)
dx = (rng.random(shape) - 0.5) * 5
dy = (rng.random(shape) - 0.5) * 5
_raytrace_np = raytrace_np(ir, dx, dy)
_raytrace_jax = raytrace_jax(ir, dx, dy).block_until_ready()
assert np.allclose(_raytrace_np, _raytrace_jax)
%timeit raytrace_np(ir, dx, dy)
%timeit raytrace_jax(ir, dx, dy).block_until_ready()
Output:
14.3 ms ± 84.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
62.9 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
So is there a way to implement this algorithm in JAX with performance closer to that of the Numba implementation?
The way you implemented it in JAX is pretty close to what I'd recommend. Yes, by your timings it's roughly 4x slower than a custom Numba implementation on CPU, but I think for an operation like this, that is to be expected.
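For illustration only, here is a sketch of one small variation on the same approach: the four corner updates are stacked along a leading axis and applied in a single `.at[...].add` call. The name `raytrace_jax_stacked` is hypothetical and not from the question, and whether this is any faster than four separate scatters on a given backend is an assumption that would need to be measured.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


@jax.jit
def raytrace_jax_stacked(ir, dx, dy):
    """Hypothetical variant: all four corner updates in one scatter-add."""
    n, m = ir.shape
    dxfloor = jnp.floor(dx)
    dyfloor = jnp.floor(dy)
    fx = dx - dxfloor  # fractional offsets in [0, 1)
    fy = dy - dyfloor
    row0 = dyfloor.astype(jnp.int64) + jnp.arange(n)[:, None]
    col0 = dxfloor.astype(jnp.int64) + jnp.arange(m)[None, :]

    # Stack the four bilinear corners along a new leading axis: shape (4, n, m).
    rows = jnp.stack([row0, row0, row0 + 1, row0 + 1])
    cols = jnp.stack([col0, col0 + 1, col0, col0 + 1])
    weights = jnp.stack(
        [(1 - fx) * (1 - fy), fx * (1 - fy), (1 - fx) * fy, fx * fy]
    )

    valid = (rows >= 0) & (rows <= n - 1) & (cols >= 0) & (cols <= m - 1)
    # Out-of-bounds corners are redirected to (0, 0) with a zero contribution,
    # the same masking trick as in the question's implementation.
    rows = jnp.where(valid, rows, 0)
    cols = jnp.where(valid, cols, 0)
    vals = jnp.where(valid, ir * weights, 0)

    # .at[...].add accumulates contributions that land on the same cell.
    return jnp.zeros_like(ir).at[rows, cols].add(vals)


rng = np.random.default_rng(0)
shape = (200, 200)
ir = rng.random(shape)
dx = (rng.random(shape) - 0.5) * 5
dy = (rng.random(shape) - 0.5) * 5
result = raytrace_jax_stacked(ir, dx, dy).block_until_ready()
```

The result should agree with `raytrace_jax` up to floating-point summation order, so it is worth checking with `np.allclose` before comparing timings.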
The operation you defined applies specific logic to each individual entry of the array – that is precisely the computational regime that Numba is designed for, and precisely the kind of computation that CPUs were designed for: it's not surprising that with Numba on CPU your computation is very fast.
I suspect the reason you used Numba rather than NumPy here is that NumPy is not designed for this sort of algorithm: it is an array-oriented language, not an array-element-oriented language. JAX/XLA is more similar to NumPy than to Numba: it is an array-oriented language; it encodes operations across whole arrays at once, rather than choosing a different computation per-element.
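To make the distinction concrete, here is a toy one-dimensional sketch (the arrays `values` and `targets` are made up for illustration). The loop chooses what to do for each element, while the array-oriented version computes a mask for the whole array and issues one scatter-add with `np.add.at`.

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])
targets = np.array([0, 2, 2, 5])  # index 5 is out of bounds for an output of size 4
size = 4

# Element-oriented: decide per element whether and where to write.
out_loop = np.zeros(size)
for v, t in zip(values, targets):
    if 0 <= t < size:
        out_loop[t] += v

# Array-oriented: one mask over the whole array, then one scatter-add.
# np.add.at accumulates repeated indices, unlike out[targets] += values.
valid = (targets >= 0) & (targets < size)
out_vec = np.zeros(size)
np.add.at(out_vec, targets[valid], values[valid])

print(out_loop)  # [1. 0. 5. 0.]
print(out_vec)   # [1. 0. 5. 0.]
```

Both versions give the same result; the second is the shape of computation that NumPy and JAX/XLA are built around.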
The benefit of this array-oriented computing model becomes really apparent when you move away from CPU and run the code on an accelerator like a GPU or TPU: this hardware is specifically designed for vectorized array operations, which is why you found that the same, array-oriented code was 10x faster on GPU.
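When comparing backends, it helps to confirm which one JAX is actually using and to exclude compilation from the timing. The following is a minimal sketch; `scaled_sum` and `time_call` are hypothetical helpers, not part of JAX.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Stand-in workload; substitute the function you want to benchmark.
    return jnp.sum(x * 2.0)


def time_call(fn, *args, repeats=5):
    """Average wall-clock seconds per call, excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up: triggers JIT compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for the result before stopping the clock.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.ones((1000, 1000))
print("backend:", jax.default_backend())
print("devices:", jax.devices())
print("seconds per call:", time_call(scaled_sum, x))
```

Running the same script on a machine with a GPU-enabled JAX install should report a different backend without any change to the function itself.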