
What is the proper way to update multiple indexes of a 2D (or N-dimensional) JAX array at once?


This is a follow-up to my previous question on batch updates for a 1D JAX array, where the goal was to avoid creating millions of arrays during training.

I have tried:

    import jax.numpy as jnp

    x = jnp.zeros((3, 3))

    # Update 1 index at a time
    x = x.at[2, 2].set(1)  # or x = x.at[(2, 2)].set(1)
    print(x)
    [[0. 0. 0.]
     [0. 0. 0.]
     [0. 0. 1.]]

Nice, it works. But how about 2 indexes at the same time?

    x = jnp.zeros((3, 3))
    x = x.at[(1, 0), (0, 1)].set([1, 3])
    print(x)
    [[0. 3. 0.]
     [1. 0. 0.]
     [0. 0. 0.]]

It works again, but when I try to update 3 or more indexes at once, I get an error:

    x = x.at[(1, 0), (0, 1), (1, 1)].set([1, 3, 6])
    print(x)
    IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 2.

I have spent some time browsing through JAX's documentation, but I couldn't find the best way to do this. Any help?


Solution

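  • The index you pass to .at is read NumPy-style: the first sequence holds the row indices and the second holds the column indices; it is not a list of (row, column) pairs. The error message hints at this: it complains about 3 indices for a 2-dimensional array ("dim 2"), which has only axis 0 (rows) and axis 1 (columns). To set (1, 0), (0, 1) and (1, 1), list all the row indices first and all the column indices second: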

    x = jnp.zeros((3, 3))
    x = x.at[(1, 0, 1), (0, 1, 1)].set([1, 3, 6])
    print(x)
    [[0. 3. 0.]
     [1. 6. 0.]
     [0. 0. 0.]]
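If the coordinates arrive as (row, col) pairs, as in the question, transposing them gives exactly the one-index-array-per-axis form that .at expects, and the same trick carries over to arrays with more dimensions. This is a minimal sketch, assuming the coordinates fit in an integer array of shape (num_updates, ndim); set_at_coords is a hypothetical helper name, not a JAX function:

```python
import jax.numpy as jnp


def set_at_coords(x, coords, values):
    """Hypothetical helper: set x[c] = v for each coordinate tuple c in coords.

    coords has shape (num_updates, x.ndim), one (row, col, ...) tuple per row.
    Transposing gives shape (x.ndim, num_updates), i.e. one index array per
    axis, which is the form .at[...] expects.
    """
    coords = jnp.asarray(coords)
    return x.at[tuple(coords.T)].set(values)


# 2D: the (row, col) pairs from the question
pairs = [(1, 0), (0, 1), (1, 1)]
x = set_at_coords(jnp.zeros((3, 3)), pairs, jnp.array([1.0, 3.0, 6.0]))
print(x)

# 3D works the same way: each row of coords is an (i, j, k) triple
y = set_at_coords(
    jnp.zeros((2, 2, 2)), [(0, 0, 1), (1, 1, 0)], jnp.array([5.0, 7.0])
)
print(y)
```

The 2D call produces the same array as x.at[(1, 0, 1), (0, 1, 1)].set([1, 3, 6]). If the same coordinate appears more than once, the order in which .set applies the conflicting updates is not guaranteed; .add accumulates all of them instead.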