What is the proper way to update multiple indices of a 2D (or multidimensional) JAX array at once?
This is a follow-up to my previous question on batch updates for a 1D JAX array, with the goal of avoiding creating millions of arrays during training.
I have tried:
import jax.numpy as jnp
x = jnp.zeros((3,3))
# Update 1 index at a time
x = x.at[2, 2].set(1) # or x = x.at[(2, 2)].set(1)
print(x)
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 1.]]
# Nice, it works.
# but how about 2 indexes at the same time?
x = jnp.zeros((3,3))
x = x.at[(1, 0), (0, 1)].set([1, 3])
print(x)
[[0. 3. 0.]
[1. 0. 0.]
[0. 0. 0.]]
It works again, but when I try to update 3 or more indices:
x = x.at[(1, 0), (0, 1), (1, 1)].set([1, 3, 6])
print(x)
IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 2.
I have spent some time browsing through JAX's documentation, but I couldn't find the right way to do this. Any help?
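The tuples you pass to .at[] are not (row, column) pairs: the first tuple holds the row indices and the second holds the column indices. The error message hints at this: it reports 3 indices for an array that has only 2 dimensions (dim 0 is rows, dim 1 is columns). With two tuples it only appeared to work, because (1, 0) was read as the row indices and (0, 1) as the column indices, which happens to select the same cells, (1, 0) and (0, 1), that you intended. This should give the desired behavior:
x = x.at[(1, 0, 1), (0, 1, 1)].set([1, 3, 6])
print(x)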
[[0. 3. 0.]
[1. 6. 0.]
[0. 0. 0.]]
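If your indices start out as a list of (row, col) pairs, as in the question, you can split them into one index array per dimension before calling .at[]. The following is a minimal sketch; set_at_coords is a hypothetical helper name, not part of JAX's API. The same splitting works for any number of dimensions, as long as each coordinate row has x.ndim entries. Wrapping the helper in jax.jit is optional; the JAX documentation for .at notes that inside jit-compiled code, when the input array is not reused, the compiler can perform the update in place, which is relevant to the goal of not creating many intermediate arrays.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_at_coords(x, coords, values):
    """Hypothetical helper: set x[c] = v for each row c of coords.

    coords has shape (k, x.ndim), one coordinate tuple per row;
    values has shape (k,).
    """
    coords = jnp.asarray(coords)
    # Turn the (k, ndim) list of coordinates into one index array per
    # dimension, i.e. (rows, cols, ...), which is the form .at[] expects.
    per_dim = tuple(coords[:, d] for d in range(coords.shape[1]))
    return x.at[per_dim].set(values)


# 2D: the three (row, col) pairs from the question.
x = jnp.zeros((3, 3))
pairs = jnp.array([[1, 0], [0, 1], [1, 1]])
x = set_at_coords(x, pairs, jnp.array([1.0, 3.0, 6.0]))
print(x)
# [[0. 3. 0.]
#  [1. 6. 0.]
#  [0. 0. 0.]]

# 3D: the same helper works because each coordinate row has 3 entries.
y = jnp.zeros((2, 2, 2))
y = set_at_coords(y, jnp.array([[0, 0, 1], [1, 1, 0]]), jnp.array([5.0, 7.0]))
print(y[0, 0, 1], y[1, 1, 0])
```

Transposing a pair list this way is equivalent to writing the row and column tuples by hand, as in the answer above; it just saves you from regrouping the indices manually.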