
Update JAX array based on values in another array


I have a JAX array X like this:

[[[0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]]]

How do I set entries of this array to 1 at the last-axis positions given by array Y? Each Y[i, j] lists the positions to set in row X[i, j]:

[[[1 2]
  [1 2]]

 [[0 2]
  [0 1]]

 [[1 0]
  [1 0]]]

Desired output:

[[[0. 1. 1.]
  [0. 1. 1.]]

 [[1. 0. 1.]
  [1. 1. 0.]]

 [[1. 1. 0.]
  [1. 1. 0.]]]

Solution

  • There are a couple of ways to approach this. First, let's define the arrays:

    import jax
    import jax.numpy as jnp
    
    x = jnp.zeros((3, 2, 3))
    indices = jnp.array([[[1, 2],
                          [1, 2]],
                         [[0, 2],
                          [0, 1]],
                         [[1, 0],
                          [1, 0]]])
    

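    One way is to use NumPy-style broadcasting of index arrays. Here i has shape (3, 1, 1) and j has shape (2, 1), so together with indices (shape (3, 2, 2)) they broadcast to a common (3, 2, 2) shape, pairing each entry indices[a, b, k] with row a and column b. It might look like this: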

    i = jnp.arange(3).reshape(3, 1, 1)
    j = jnp.arange(2).reshape(2, 1)
    x = x.at[i, j, indices].set(1)
    print(x)
    
    [[[0. 1. 1.]
      [0. 1. 1.]]
    
     [[1. 0. 1.]
      [1. 1. 0.]]
    
     [[1. 1. 0.]
      [1. 1. 0.]]]
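
    The broadcasting approach above hard-codes the leading shape (3, 2). The following is a minimal sketch of the same idea that builds the index grids from x.shape, so it should work for any number of leading axes. The helper name set_last_axis is hypothetical, and it assumes idx has the same leading shape as x, with its final axis listing the positions to set.

```python
import jax.numpy as jnp


def set_last_axis(x, idx, value=1):
    """Hypothetical helper: set x[..., idx] = value along the last axis.

    Assumes idx.shape[:-1] == x.shape[:-1]; idx[..., k] are positions
    in the corresponding last-axis row of x.
    """
    lead = x.shape[:-1]
    grids = []
    for axis, n in enumerate(lead):
        # arange(n) placed on this axis, size 1 everywhere else, plus a
        # trailing 1 so it broadcasts against the last axis of idx.
        shape = [1] * (len(lead) + 1)
        shape[axis] = n
        grids.append(jnp.arange(n).reshape(shape))
    # All index arrays broadcast to idx.shape, like i, j and indices above.
    return x.at[(*grids, idx)].set(value)


x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2], [1, 2]],
                     [[0, 2], [0, 1]],
                     [[1, 0], [1, 0]]])
out = set_last_axis(x, indices)
print(out)
```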
    

    Another option is to nest jax.vmap twice, mapping a single-row update over the first two axes of x and indices:

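    # The outer vmap maps over axis 0 and the inner one over axis 1, so each
    # call sees one length-3 row and its length-2 list of positions.
    # x was reassigned above, so start again from zeros.
    f = jax.vmap(jax.vmap(lambda row, idx: row.at[idx].set(1)))
    print(f(jnp.zeros((3, 2, 3)), indices))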
    
    [[[0. 1. 1.]
      [0. 1. 1.]]
    
     [[1. 0. 1.]
      [1. 1. 0.]]
    
     [[1. 1. 0.]
      [1. 1. 0.]]]
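
    A third option, not part of the original answer, swaps the indexed update for a mask. jax.nn.one_hot turns each index into an indicator row, and taking the max over the pair of rows for each (i, j) gives a 0/1 mask with the same shape as x. The function name set_by_mask is hypothetical. This sketch assumes every index lies in [0, 3): one_hot gives an all-zero row for out-of-range or negative indices, whereas .at[...] wraps negative indices NumPy-style.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_by_mask(x, idx):
    # (3, 2, 2) -> (3, 2, 2, 3): each index becomes a length-3 indicator row.
    hot = jax.nn.one_hot(idx, x.shape[-1], dtype=x.dtype)
    # Combine the indicator rows for each (i, j) into a (3, 2, 3) mask.
    mask = hot.max(axis=-2)
    # Set masked positions to 1 and keep the rest of x unchanged.
    return jnp.where(mask > 0, 1.0, x)


x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2], [1, 2]],
                     [[0, 2], [0, 1]],
                     [[1, 0], [1, 0]]])
result = set_by_mask(x, indices)
print(result)
```

    This version materializes an intermediate one-hot array of shape (3, 2, 2, 3), so when the last axis is large the scatter-based .at approaches above are likely cheaper.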