I am trying to apply a number of in-place updates to a 2D matrix.
It appears that applying jit to the in-place update has no effect on computation time (which is many orders of magnitude longer than the equivalent numpy update).
Here is code that demonstrates my problem and research.
import numpy as onp
import jax.numpy as np
from jax import jit

node_count = 10000
b = onp.zeros([node_count,node_count])
print("`numpy` in place update.")
%timeit b[1,1] = 1.
# 86.9 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
a = np.zeros([node_count,node_count])
print("`jax.np` in place update.")
%timeit a.at[1,1].set(1.)
# 112 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
def update(mat, index, val):
    return mat.at[tuple(index)].set(val)
update_jit = jit(update)
# Run once for trace.
update_jit(a, [1,1], 1.).block_until_ready()
print("`jax.np` jit in place update.")
%timeit update_jit(a, [1,1],1.).block_until_ready()
# 99.6 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
This has nothing to do with inlining of in-place updates. It has to do with the fact that, unless otherwise requested, a JIT-compiled function will always return its result in a new, distinct buffer. The only exception is if you use buffer donation to explicitly mark that the input buffer can be reused for the output:
update_jit = jit(update, donate_argnums=[0])
Note, however, that buffer donation is currently only available on GPU and TPU runtimes.
You'll not be able to use %timeit in this case, because the donated input buffer is no longer available for use after the first iteration, but you can confirm via %time that this improves the computation speed:
# Following is run on a Colab T4 GPU runtime
update_jit = jit(update)
_ = update_jit(b, [1,1], 1.)
%time _ = update_jit(b, [1,1], 1.).block_until_ready()
# CPU times: user 607 µs, sys: 112 µs, total: 719 µs
# Wall time: 5.89 ms
update_jit_donate = jit(update, donate_argnums=[0])
b = update_jit_donate(b, [1,1], 1.)
%time _ = update_jit_donate(b, [1,1], 1.).block_until_ready()
# CPU times: user 467 µs, sys: 86 µs, total: 553 µs
# Wall time: 332 µs
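As a rough illustration of the two behaviours described above, here is a minimal sketch on a small matrix. The 4x4 shape and the names a, out and x are my own choices for illustration. On a backend where donation is not supported, the donated call should still return the right values, but JAX may warn that the donated buffer was not usable.

```python
import jax.numpy as jnp
from jax import jit

def update(mat, index, val):
    return mat.at[tuple(index)].set(val)

# Without donation: the result lives in a new buffer and the input is untouched.
update_jit = jit(update)
a = jnp.zeros((4, 4))
out = update_jit(a, (1, 1), 1.0)
print(float(a[1, 1]), float(out[1, 1]))  # 0.0 1.0

# With donation: rebind the name to the result on every call, because the
# argument passed in may be invalidated once its buffer has been donated.
update_jit_donate = jit(update, donate_argnums=[0])
x = jnp.zeros((4, 4))
for _ in range(3):
    x = update_jit_donate(x, (1, 1), 1.0)
x.block_until_ready()
print(float(x[1, 1]))  # 1.0
```

Rebinding x to the returned array on each iteration is what makes repeated calls safe; calling the donating function again on an already-donated array is what trips up %timeit.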
The buffer donation version is still quite a bit slower than the NumPy version, but this is expected for the reasons discussed in the JAX FAQ entry "Is JAX Faster Than Numpy?"
I suspect you're performing these micro-benchmarks to assure yourself that the compiler performs updates in-place within a JIT-compiled sequence of operations rather than making internal copies, as is mentioned in Sharp Bits: Array Updates. If so, you can confirm this by other means; for example:
@jit
def sum(x):
    return x.sum()

@jit
def update_and_sum(x):
    return x.at[0, 0].set(1).sum()
_ = sum(b)
%timeit sum(b).block_until_ready()
# 1.66 ms ± 7.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
_ = update_and_sum(b)
%timeit update_and_sum(b).block_until_ready()
# 1.66 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
The identical timings here show that the update operation is being performed in-place rather than causing the input buffer to be copied.
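Since the original goal was to apply many updates, one option is to express them all inside a single jitted function rather than dispatching one update per call. The following is only a sketch under that assumption: apply_updates and the sample index and value arrays are hypothetical names and data, and the indices are assumed to be distinct.

```python
import jax.numpy as jnp
from jax import jit

@jit
def apply_updates(mat, rows, cols, vals):
    # A single scatter writes vals[i] to mat[rows[i], cols[i]] for every i.
    return mat.at[rows, cols].set(vals)

mat = jnp.zeros((5, 5))
rows = jnp.array([0, 1, 2])
cols = jnp.array([0, 1, 2])
vals = jnp.array([1.0, 2.0, 3.0])

out = apply_updates(mat, rows, cols, vals)
print(out[1, 1], out.sum())
```

The returned array is still a new buffer unless the input is donated, but the per-call overhead is paid once for the whole batch instead of once per element.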