
Count occurrences of 2D coordinates onto another 2D JAX array


I have:

import jax.numpy as jnp

x = jnp.zeros((5,5))
coords = jnp.array([
    [1,2],
    [2,3],
    [1,2],
])

I would like to count, onto x, how many times each individual (row, column) coordinate pair appears in coords. In other words, I want this output:

Array([[0., 0., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

I've tried x.at[coords].add(1), but this gives me:

Array([[0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2.],
       [3., 3., 3., 3., 3.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.]], dtype=float32)

I understand what it's doing, but not how to make it do what I want.
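A minimal sketch of what x.at[coords].add(1) is doing, assuming the same x and coords as above: when a single integer array is used as the index, JAX (like NumPy) applies it to the first axis only, so each of the six entries of coords selects a whole row. The names attempt, rows_hit and row_counts are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2]])

# A single (3, 2) integer index array addresses axis 0 only:
# every one of its six entries picks out an entire row of x.
attempt = x.at[coords].add(1)

# Equivalent formulation: flatten coords into a plain list of row indices.
rows_hit = coords.ravel()  # [1, 2, 2, 3, 1, 2]
same = x.at[rows_hit].add(1)

# Each row's value is just how often that row index occurs in coords.
row_counts = jnp.bincount(rows_hit, length=x.shape[0])

print(jnp.array_equal(attempt, same))  # True
print(row_counts)                      # [0 2 3 1 0]
```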

There's this related question [1], but I haven't been able to use it to solve my problem.

[1] Update JAX array based on values in another array


Solution

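  • To index several axes at once, pass a tuple of index arrays, one per axis. Entries at the same position across those arrays form one (row, column) coordinate, and .at[...].add() applies every update, so repeated coordinates accumulate: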

    x = x.at[coords[:, 0], coords[:, 1]].add(1)
    print(x)
    
    [[0. 0. 0. 0. 0.]
     [0. 0. 2. 0. 0.]
     [0. 0. 0. 1. 0.]
     [0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0.]]
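One way to package the answer, sketched under a few assumptions: count_coords and count_coords_bincount are hypothetical helper names, coordinates are assumed to be in bounds (no checking is done), and the bincount variant is a swapped-in alternative rather than part of the answer above. Writing the index as tuple(coords.T) is equivalent to (coords[:, 0], coords[:, 1]) but extends to any number of axes, and marking shape as static lets the helper run under jax.jit.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def count_coords(shape, coords):
    """Hypothetical helper: count integer coordinates onto zeros(shape).

    coords has shape (n, len(shape)); each row is one coordinate.
    """
    # tuple(coords.T) yields one index array per axis, i.e.
    # (coords[:, 0], coords[:, 1], ...) for any number of dimensions.
    # Scatter-add applies every update, so repeated coordinates accumulate.
    return jnp.zeros(shape).at[tuple(coords.T)].add(1)


def count_coords_bincount(shape, coords):
    """Hypothetical 2D-only alternative: flatten each pair, then bincount."""
    n_rows, n_cols = shape
    # Row-major linear index of each (row, col) pair.
    flat = coords[:, 0] * n_cols + coords[:, 1]
    return jnp.bincount(flat, length=n_rows * n_cols).reshape(shape)


coords = jnp.array([[1, 2], [2, 3], [1, 2]])
print(count_coords((5, 5), coords))
print(count_coords_bincount((5, 5), coords))

# The tuple form works unchanged for 3D coordinates.
coords_3d = jnp.array([[0, 1, 1], [0, 1, 1], [1, 0, 0]])
print(count_coords((2, 2, 2), coords_3d))
```

The scatter-add version returns floats (matching jnp.zeros), while the bincount version returns integer counts; both hold the same values for the example coords.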