I want to accelerate the nested for-loops in the example below using JAX's `jit` method.
However, it takes very long to compile, and the runtime after compilation is even slower than the version without `jit`.
Do I use `jit` correctly? Are there other features in JAX that I should use here instead?
```python
import time
import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(seed=0)

width = 32
height = 64

w = random.normal(key=key, shape=(height, width))

def forward():
    a = jnp.zeros(shape=(height, width + 1))
    for i in range(height):
        a = a.at[i, 0].add(1.0)
    for j in range(width):
        for i in range(1, height-1):
            z = a[i-1, j] * w[i-1, j] \
                + a[i, j] * w[i, j] \
                + a[i+1, j] * w[i+1, j]
            a = a.at[i, j+1].set(z)

t0 = time.time()
forward()
print(time.time()-t0)

feedforward_jit = jit(forward)
t0 = time.time()
feedforward_jit()
print(time.time()-t0)
```
The short answer to your question is: to optimize your loops, you should do everything you can to remove the loops from your program.
JAX (like NumPy) is a library built around array operations, and any time you resort to looping over the dimensions of arrays in Python, JAX (like NumPy) will be slower than you'd probably like. This is particularly the case during JIT compilation: JAX unrolls Python loops while tracing, so every iteration becomes its own set of operations sent to XLA, and XLA compilation time scales roughly as the square of the number of operations it receives. In your example the inner update alone is traced 32 × 62 = 1,984 times, so nested loops like these are a great way to quickly create very slow compilations.
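As a rough illustration of that unrolling (a sketch, not a benchmark), `jax.make_jaxpr` shows the operations JAX records while tracing. For a Python loop, the equation count grows with the number of iterations. The helper `python_loop` is hypothetical and exists only for this demonstration.

```python
import jax
import jax.numpy as jnp

def python_loop(x, n):
    # Hypothetical helper: a Python-level loop that JAX unrolls while tracing
    for i in range(n):
        x = x.at[i].add(1.0)
    return x

x = jnp.zeros(64)
eqn_counts = {}
for n in (4, 16, 64):
    # make_jaxpr traces the function; each loop iteration adds its own equations
    closed = jax.make_jaxpr(lambda v: python_loop(v, n))(x)
    eqn_counts[n] = len(closed.jaxpr.eqns)
    print(n, eqn_counts[n])
```

Every one of those equations is handed to XLA, which is why compile time balloons as the loop bounds grow.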
So how can you avoid these loops? First, let's redefine your function so that it takes inputs and returns outputs. Your `forward` returns nothing, and given JAX's dead code elimination and asynchronous dispatch, I don't think your initial benchmarks are telling you what you think they are; note also that the first call to a jitted function includes compilation time. See Benchmarking JAX code for some tips:
```python
def forward(w):
    height, width = w.shape
    a = jnp.zeros(shape=(height, width + 1))
    for i in range(height):
        a = a.at[i, 0].add(1.0)
    for j in range(width):
        for i in range(1, height-1):
            z = (a[i-1, j] * w[i-1, j]
                 + a[i, j] * w[i, j]
                 + a[i+1, j] * w[i+1, j])
            a = a.at[i, j+1].set(z)
    return a
```
The first loop can be replaced by a one-line vectorized update: `a = a.at[:, 0].set(1)`. Looking at the inner loop of the next block, each new column is a convolution of the previous column (multiplied elementwise by the matching column of `w`) with a kernel of three ones. Let's use `jnp.convolve` to do that more efficiently. Applying these two optimizations gives:
```python
def forward2(w):
    height, width = w.shape
    a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
    kernel = jnp.ones(3)
    for j in range(width):
        conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
        a = a.at[1:-1, j + 1].set(conv)
    return a
```
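To see why both rewrites match the original loops, here is a small check on toy inputs (a sketch; the sizes and values are arbitrary). A `'valid'` convolution with a kernel of three ones sums each window of three neighbours, which is exactly the three-term expression in the original inner loop; because the kernel is symmetric, the flip that convolution applies makes no difference.

```python
import jax.numpy as jnp
from jax import random

# One toy column of weights and activations (arbitrary sizes and values)
w = random.normal(random.PRNGKey(0), (6,))
a = jnp.arange(1.0, 7.0)

# Three-term sum from the original inner loop, for interior rows i = 1..4
loop = jnp.array([a[i-1] * w[i-1] + a[i] * w[i] + a[i+1] * w[i+1]
                  for i in range(1, 5)])

# The same values via a 'valid' convolution with a kernel of three ones
conv = jnp.convolve(a * w, jnp.ones(3), mode='valid')
print(jnp.allclose(loop, conv))

# First optimization: the row-by-row loop vs. one vectorized update
a_loop = jnp.zeros((4, 3))
for i in range(4):
    a_loop = a_loop.at[i, 0].add(1.0)
a_vec = jnp.zeros((4, 3)).at[:, 0].set(1)
print(jnp.array_equal(a_loop, a_vec))
```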
Next, let's look at the loop over `width`. This one is trickier, because each iteration depends on the result of the previous one. One way to express that is with `lax.scan`, one of JAX's built-in control flow operators. You might do it like this:
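```python
import jax.numpy as jnp
from jax import lax

def forward3(w):
    kernel = jnp.ones(3)

    def body(carry, w_col):
        # carry is column j of `a`; w_col is column j of `w`
        conv = jnp.convolve(carry * w_col, kernel, mode='valid')
        out = jnp.zeros_like(w_col).at[1:-1].set(conv)
        return out, out  # (new carry, column to stack into the output)

    init = jnp.ones(w.shape[0])
    _, cols = lax.scan(body, init, w.T)
    return jnp.vstack([init, cols]).T
```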
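If `lax.scan` is unfamiliar: it calls `body(carry, x)` once for each slice along the leading axis of `xs`, threads the carry from one step to the next, and returns `(final_carry, stacked_outputs)`. That is why the column-wise recurrence above scans over `w.T` and keeps only the stacked outputs. A minimal sketch with a running sum:

```python
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    total = carry + x
    return total, total  # (next carry, per-step output)

# Start from a float32 zero so the carry dtype stays the same across steps
final, partial_sums = lax.scan(body, jnp.zeros(()), jnp.arange(1.0, 5.0))
# final == 10.0, partial_sums == [1, 3, 6, 10]
print(final, partial_sums)
```

Unlike a Python loop, `scan` traces `body` only once, so the compiled program stays small regardless of how many steps it runs.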
We can quickly confirm that the three approaches give the same outputs:
```python
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))

result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)

assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)
```
Using IPython's `%time` magic, we can get a rough idea of the computation time of each approach, here on a CPU backend (note the use of `block_until_ready()` to account for JAX's asynchronous dispatch):
```python
%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s

%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms

%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms
```
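The timings above run the functions without `jit`. When you do use `jit`, it helps to measure compilation and execution separately. Here is a sketch using JAX's ahead-of-time path (`jax.jit(f).lower(...).compile()`); the helper name `time_compile_and_run` and the stand-in function `f` are hypothetical, chosen only for illustration.

```python
import time
import jax
import jax.numpy as jnp

def time_compile_and_run(f, *args):
    """Hypothetical helper: time tracing+compilation and execution of jax.jit(f) separately."""
    t0 = time.perf_counter()
    compiled = jax.jit(f).lower(*args).compile()  # trace and XLA-compile only
    t1 = time.perf_counter()
    out = compiled(*args).block_until_ready()     # run, then wait for async dispatch
    t2 = time.perf_counter()
    return out, t1 - t0, t2 - t1

# Stand-in function; any jittable function of arrays works here
def f(x):
    return jnp.tanh(x) @ x.T

x = jnp.ones((64, 32))
out, compile_s, run_s = time_compile_and_run(f, x)
print(f"compile: {compile_s:.4f} s, run: {run_s:.4f} s")
```

Passing `forward` or `forward3` (with `w`) in place of `f` shows how much of the total cost is compilation in each case.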
You can read more about JAX and control flow at https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow.