I have:
import jax.numpy as jnp
x = jnp.zeros((5, 5))
coords = jnp.array([
    [1, 2],
    [2, 3],
    [1, 2],
])
I would like to record in x how many times each individual (x, y) coordinate appears in coords. In other words, I want to obtain this output:
Array([[0., 0., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
I've tried x.at[coords].add(1), and this gives me:
Array([[0., 0., 0., 0., 0.],
       [2., 2., 2., 2., 2.],
       [3., 3., 3., 3., 3.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.]], dtype=float32)
I understand what it's doing, but not how to make it do the thing I want.
There's a related question [1], but I haven't been able to apply it to my problem.
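To see what x.at[coords].add(1) is actually doing: a single integer array used as an index applies to the first axis only, so every value in coords is treated as a row number and whole rows are updated. The short sketch below (using only the arrays from the question) makes that visible by checking the gathered shape and the per-row totals.

```python
import jax.numpy as jnp

x = jnp.zeros((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2]])

# A single (3, 2) integer array is an advanced index on axis 0 only,
# so x[coords] gathers whole rows and has shape (3, 2, 5).
print(x[coords].shape)  # (3, 2, 5)

# Every entry of coords is therefore used as a row index:
# 1 appears twice, 2 three times and 3 once, and .add() accumulates
# duplicates, which reproduces the row totals in the output above.
rows_hit = x.at[coords].add(1)
print(rows_hit[:, 0])  # [0. 2. 3. 1. 0.]
```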
To index along multiple axes, pass a tuple of index arrays, one per axis. Here coords[:, 0] holds the row indices and coords[:, 1] the column indices:
x = x.at[coords[:, 0], coords[:, 1]].add(1)
print(x)
[[0. 0. 0. 0. 0.]
 [0. 0. 2. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
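The same idea extends to any number of dimensions, because tuple(coords.T) is equivalent to (coords[:, 0], coords[:, 1], ...). Below is a minimal sketch; count_coords is a hypothetical helper name, and using int32 counts (instead of the float32 array from the question) is my own choice. Because .add() is a scatter-add, repeated coordinates accumulate rather than overwrite each other.

```python
import jax
import jax.numpy as jnp

def count_coords(shape, coords):
    """Count how often each index tuple (one per row of coords) occurs.

    coords has shape (n, len(shape)); column k holds indices into axis k.
    """
    counts = jnp.zeros(shape, dtype=jnp.int32)
    # tuple(coords.T) gives one index array per axis; .add() accumulates duplicates.
    return counts.at[tuple(coords.T)].add(1)

coords = jnp.array([[1, 2], [2, 3], [1, 2]])
print(count_coords((5, 5), coords))

# The same function works unchanged for a 3-D grid and under jax.jit;
# shape must be a static (hashable) argument, so it is marked as such.
count_coords_jit = jax.jit(count_coords, static_argnums=0)
coords3d = jnp.array([[0, 1, 2], [0, 1, 2], [3, 0, 1]])
print(count_coords_jit((4, 4, 4), coords3d)[0, 1, 2])  # 2
```

Passing the tuple keeps the whole computation as a single scatter operation, so no Python loop over the coordinates is needed.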