Since arrays are immutable in JAX, updating N indices one at a time creates N new arrays, one for each update of the form
x = x.at[idx].set(y)
With hundreds of updates per training cycle, this could end up creating hundreds, if not millions, of arrays. That seems wasteful. Is there a way to update multiple indices in one go? Is there overhead, and if so, is it significant? Am I overlooking something?
You can perform multiple updates in a single operation with the syntax you mention by passing an array of indices and a matching array of values. For example:
import jax.numpy as jnp
x = jnp.zeros(10)
idx = jnp.array([3, 5, 7, 9])
y = jnp.array([1, 2, 3, 4])
x = x.at[idx].set(y)
print(x)
# [0. 0. 0. 1. 0. 2. 0. 3. 0. 4.]
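As a quick check, here is a minimal sketch comparing the single batched update with the equivalent loop of single-index updates the question describes. The names `batched`, `looped` and `z` are purely illustrative. It also shows that the same batched indexing works with other `.at` methods such as `.add`, where repeated indices accumulate.

```python
import jax.numpy as jnp

x = jnp.zeros(10)
idx = jnp.array([3, 5, 7, 9])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

# One scatter operation that updates all four positions at once.
batched = x.at[idx].set(y)

# Same result built from four single-index updates; run eagerly
# (outside jit), each step produces a new array.
looped = x
for i, v in zip(idx, y):
    looped = looped.at[i].set(v)

print(jnp.array_equal(batched, looped))  # True

# Batched indexing also works with .add; repeated indices accumulate.
z = x.at[jnp.array([1, 1, 2])].add(1.0)
print(z)  # [0. 2. 1. 0. 0. 0. 0. 0. 0. 0.]
```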
You're correct that outside of JIT, each update operation creates a copy of the array. Within JIT-compiled functions, however, the XLA compiler can perform such updates in place when it is safe to do so (for example, when the original array is not referenced again). You can read more in JAX Sharp Bits: Array Updates.
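To illustrate, here is a minimal sketch of a JIT-compiled function that performs a sequence of single-index updates inside `jax.lax.fori_loop`. The function name `fill_even_positions` is hypothetical. Because each intermediate array is consumed only by the next update, XLA is free to reuse the buffer. Whether it actually does so is a compiler optimization and is not guaranteed by this code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_even_positions(x):
    # Sets x[2 * i] = i for i in 0..len(x)//2 - 1. Each step only feeds
    # the next one, so XLA may apply the scatters in place.
    def body(i, arr):
        return arr.at[2 * i].set(i)
    return jax.lax.fori_loop(0, x.shape[0] // 2, body, x)

x = jnp.zeros(10)
result = fill_even_positions(x)
print(result)  # [0. 0. 1. 0. 2. 0. 3. 0. 4. 0.]
```

Note that the caller's `x` is still valid afterwards; in-place reuse only applies to intermediate buffers inside the compiled function.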