I have a JAX array X like this:
[[[0. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]]]
How do I set entries of this array to 1, where the positions along the last axis are given by array Y (for each X[a, b], every value in Y[a, b] is an index to set)?
[[[1 2]
[1 2]]
[[0 2]
[0 1]]
[[1 0]
[1 0]]]
Desired output:
[[[0., 1., 1.],
[0., 1., 1.]],
[[1., 0., 1.],
[1., 1., 0.]],
[[1., 1., 0.],
[1., 1., 0.]]]
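To pin down the intended semantics, here is a plain NumPy loop that produces the desired output. It is only a reference sketch (slow, and not how you would write it in JAX); the names X, Y and expected are illustrative. For each (a, b), every value in Y[a, b] is treated as a position along the last axis of X[a, b].

```python
import numpy as np

# Same inputs as in the question: X is all zeros, Y holds last-axis positions.
X = np.zeros((3, 2, 3))
Y = np.array([[[1, 2], [1, 2]],
              [[0, 2], [0, 1]],
              [[1, 0], [1, 0]]])

# Plain-loop reference: each entry of Y[a, b] names a position
# along the last axis of X[a, b] that should be set to 1.
expected = X.copy()
for a in range(Y.shape[0]):
    for b in range(Y.shape[1]):
        for k in Y[a, b]:
            expected[a, b, k] = 1

print(expected)
```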
There are a couple of ways to approach this. First, let's define the arrays:
import jax
import jax.numpy as jnp
x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2],
[1, 2]],
[[0, 2],
[0, 1]],
[[1, 0],
[1, 0]]])
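One way is to use standard NumPy-style broadcasting of index arrays: build index arrays for the first two axes that broadcast against indices, then pass all three to x.at[...]. It might look like this: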
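i = jnp.arange(3).reshape(3, 1, 1)  # index along axis 0, shape (3, 1, 1)
j = jnp.arange(2).reshape(2, 1)  # index along axis 1, broadcasts as (1, 2, 1)
result = x.at[i, j, indices].set(1)  # i, j and indices broadcast to shape (3, 2, 2)
print(result)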
[[[0. 1. 1.]
[0. 1. 1.]]
[[1. 0. 1.]
[1. 1. 0.]]
[[1. 1. 0.]
[1. 1. 0.]]]
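To see why this works, here is a small sketch that redefines i, j and indices as above (so it runs on its own), checks the common broadcast shape, and materializes the broadcast index arrays with jnp.broadcast_arrays. The names I, J and K are illustrative. Each triple (I[a, b, c], J[a, b, c], K[a, b, c]) is one position in x that gets set to 1.

```python
import jax.numpy as jnp

i = jnp.arange(3).reshape(3, 1, 1)  # varies along axis 0
j = jnp.arange(2).reshape(2, 1)     # varies along axis 1 (treated as (1, 2, 1))
indices = jnp.array([[[1, 2], [1, 2]],
                     [[0, 2], [0, 1]],
                     [[1, 0], [1, 0]]])

# All three index arrays broadcast to a common shape, equal to indices.shape.
print(jnp.broadcast_shapes(i.shape, j.shape, indices.shape))  # (3, 2, 2)

# Expand them explicitly to inspect the (row, col, position) triples.
I, J, K = jnp.broadcast_arrays(i, j, indices)
# For x[1, 1], the triples are (1, 1, 0) and (1, 1, 1).
print(I[1, 1], J[1, 1], K[1, 1])
```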
Another option is to write the update for a single row and use a nested jax.vmap to batch it over the two leading axes:
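# The outer vmap maps over axis 0 and the inner one over axis 1,
# so the lambda sees one length-3 row and its two indices.
f = jax.vmap(jax.vmap(lambda row, idx: row.at[idx].set(1)))
print(f(x, indices))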
[[[0. 1. 1.]
[0. 1. 1.]]
[[1. 0. 1.]
[1. 1. 0.]]
[[1. 1. 0.]
[1. 1. 0.]]]
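As a quick sanity check, both approaches can be wrapped in functions, compiled with jax.jit, and compared. This is a sketch; the function names set_by_broadcast and set_by_vmap are hypothetical. The broadcasting version reads the leading dimensions from x.shape, which is static under jit, so it assumes x has shape (n, m, k) and indices has shape (n, m, p).

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2], [1, 2]],
                     [[0, 2], [0, 1]],
                     [[1, 0], [1, 0]]])

def set_by_broadcast(x, indices):
    # Approach 1: index arrays for axes 0 and 1 broadcast against `indices`.
    i = jnp.arange(x.shape[0]).reshape(-1, 1, 1)
    j = jnp.arange(x.shape[1]).reshape(-1, 1)
    return x.at[i, j, indices].set(1)

# Approach 2: per-row update, vmapped over the two leading axes.
set_by_vmap = jax.vmap(jax.vmap(lambda row, idx: row.at[idx].set(1)))

a = jax.jit(set_by_broadcast)(x, indices)
b = jax.jit(set_by_vmap)(x, indices)
print(jnp.array_equal(a, b))  # True
```

Which one to prefer is mostly a matter of taste: the broadcasting form is a single scatter written explicitly, while the vmap form keeps the per-row logic simple and lets JAX handle the batching.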