
Is there a way to update multiple indexes of a JAX array at once?


Since arrays are immutable in JAX, updating N indexes one at a time creates N arrays with

x = x.at[idx].set(y)

With hundreds of updates per training cycle, this will ultimately create hundreds, if not millions, of arrays. This seems wasteful. Is there a way to update multiple indexes in one go? Does anyone know whether there is overhead, and if so, whether it is significant? Or am I overlooking something?
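For concreteness, here is a minimal sketch of the two patterns being compared. The sizes and values are illustrative. The loop applies one `.at[i].set(v)` per index, which outside `jax.jit` builds a new array on every iteration. The last line does the same work as a single batched update by passing an array of indices.

```python
import jax.numpy as jnp

x = jnp.zeros(10)
idx = jnp.array([3, 5, 7, 9])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

# One-at-a-time: outside jit, every iteration produces a fresh array.
x_loop = x
for i, v in zip(idx, y):
    x_loop = x_loop.at[i].set(v)

# Batched: one scatter that writes all four positions at once.
x_batch = x.at[idx].set(y)

print(jnp.array_equal(x_loop, x_batch))  # True
```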


Solution

  • You can perform multiple updates in a single operation using the syntax you mention, by passing an array of indices. For example:

    import jax.numpy as jnp
    
    x = jnp.zeros(10)
    idx = jnp.array([3, 5, 7, 9])
    y = jnp.array([1, 2, 3, 4])
    
    x = x.at[idx].set(y)
    print(x)
    # [0. 0. 0. 1. 0. 2. 0. 3. 0. 4.]
    

    You're correct that outside of JIT, each update operation creates a copy of the array. Within JIT-compiled functions, however, the compiler can perform such updates in place when possible (for example, when the original array is not referenced again). You can read more in the JAX Sharp Bits documentation, under "Array updates".
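A minimal sketch of the JIT behaviour described above. The function `apply_updates` and its values are hypothetical. Inside `jax.jit`, intermediate arrays are never visible to the caller, so XLA is allowed to reuse a single buffer across the chained updates. The input itself can only be reused if the caller gives it up, which is what the real `donate_argnums` parameter of `jax.jit` expresses. Whether a buffer is actually reused depends on the backend, so the comments describe what JIT permits, not a guarantee.

```python
import jax
import jax.numpy as jnp


def apply_updates(x, idx, y):
    # Two chained updates. Under jit the intermediate result is never
    # observed outside the function, so XLA may update it in place.
    x = x.at[idx].set(y)
    x = x.at[idx].multiply(2.0)
    return x


# Plain jit: intermediates may be updated in place, but the input buffer
# is kept alive because the caller could still read it.
apply_jit = jax.jit(apply_updates)

# donate_argnums=0 promises the caller will not reuse `x`, so XLA may
# write the output into x's buffer. If a backend cannot use the donated
# buffer, JAX falls back to a copy and emits a warning.
apply_donated = jax.jit(apply_updates, donate_argnums=0)

idx = jnp.array([3, 5, 7, 9])
y = jnp.array([1.0, 2.0, 3.0, 4.0])

out = apply_jit(jnp.zeros(10), idx, y)
print(out)
# [0. 0. 0. 2. 0. 4. 0. 6. 0. 8.]

x = jnp.zeros(10)
out2 = apply_donated(x, idx, y)
# Do not read `x` after this call: its buffer may have been donated.
```

Donation matters mainly for large arrays updated repeatedly in a training step, where avoiding one extra copy of the input per call can be worthwhile.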