python, numpy, jax

Cross dimensional segmented operation


Say you have the following array a:

>>> a = np.arange(27).reshape((3,3,3))
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]], dtype=int64)

And m, an array of the same shape that gives a segment id for each element of a:

>>> m = np.linspace(start=0, stop=6, num=27).astype(int).reshape(a.shape)
>>> m
array([[[0, 0, 0],
        [0, 0, 1],
        [1, 1, 1]],

       [[2, 2, 2],
        [2, 3, 3],
        [3, 3, 3]],

       [[4, 4, 4],
        [4, 5, 5],
        [5, 5, 6]]])

To sum, with JAX, the scalars in a that share the same id in m, we can rely on jax.ops.segment_sum:

>>> jax.ops.segment_sum(data=a.ravel(), segment_ids=m.ravel())
Array([10, 26, 42, 75, 78, 94, 26], dtype=int64)

Note that I had to resort to numpy.ndarray.ravel, since jax.ops.segment_sum expects segment_ids to label the entries of data along its leading axis only.


Q1: Can you confirm there is no better approach, either with or without JAX?

Q2: How would one then build n, the array obtained by replacing each id in m with the corresponding sum? Note that I am not interested in non-vectorized approaches such as numpy.where.

>>> n
array([[[10, 10, 10],
        [10, 10, 26],
        [26, 26, 26]],

       [[42, 42, 42],
        [42, 75, 75],
        [75, 75, 75]],

       [[78, 78, 78],
        [78, 94, 94],
        [94, 94, 26]]], dtype=int64)

Solution

  • The segment_sum operation is somewhat more specialized than what you're asking about. In the case you describe, I would use ndarray.at directly: scatter-add a into one slot per id, then index the resulting sums with m to get n:

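    import jax.numpy as jnp

    # One accumulator slot per id; jnp.zeros defaults to a float dtype,
    # which is why the output below is printed as floats.
    sums = jnp.zeros(m.max() + 1).at[m].add(a)
    print(sums[m])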
    
    [[[10. 10. 10.]
      [10. 10. 26.]
      [26. 26. 26.]]
    
     [[42. 42. 42.]
      [42. 75. 75.]
      [75. 75. 75.]]
    
     [[78. 78. 78.]
      [78. 94. 94.]
      [94. 94. 26.]]]
    

    This will also work when the segments are non-adjacent.
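
As a rough sketch of the same scatter-then-gather idea, the snippet below keeps the integer dtype of the input and adds a NumPy-only counterpart built on np.bincount, which may help with the "with or without JAX" part of Q1. The function names segment_sum_broadcast and segment_sum_broadcast_np, the num_segments parameter, and the m_scattered ids are illustrative assumptions and not part of the original answer.

```python
import jax.numpy as jnp
import numpy as np


def segment_sum_broadcast(a, m, num_segments):
    """Sum `a` per id in `m`, then place each sum at every position of its id."""
    a = jnp.asarray(a)
    m = jnp.asarray(m)
    # Scatter-add: positions that share an id accumulate into the same slot.
    sums = jnp.zeros(num_segments, dtype=a.dtype).at[m].add(a)
    # Gather: indexing the per-id sums with `m` restores the shape of `a`.
    return sums[m]


def segment_sum_broadcast_np(a, m, num_segments):
    """NumPy-only counterpart (hypothetical helper) based on np.bincount."""
    sums = np.bincount(m.ravel(), weights=a.ravel(), minlength=num_segments)
    # bincount with weights returns float64, so cast back to the input dtype.
    return sums.astype(a.dtype)[m]


a = np.arange(27).reshape((3, 3, 3))
m = np.linspace(start=0, stop=6, num=27).astype(int).reshape(a.shape)
num_segments = int(m.max()) + 1

n_jax = segment_sum_broadcast(a, m, num_segments)
n_np = segment_sum_broadcast_np(a, m, num_segments)

# Ids whose members are not adjacent: 0, 1, ..., 6, 0, 1, ...
m_scattered = (np.arange(27) % num_segments).reshape(a.shape)
n_scattered = segment_sum_broadcast(a, m_scattered, num_segments)
```

Passing num_segments explicitly, rather than computing m.max() inside the function, is a choice made for this sketch so that the size of the accumulator is known up front; neither helper relies on the ids being sorted or contiguous.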