
How to broadcast numpy indexing along batch dimensions?


For example, np.array([[1,2],[3,4]])[np.triu_indices(2)] has shape (3,): it is a flattened array of the upper-triangular entries. However, suppose I have a batch of 2x2 matrices:

foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)

and I want to extract the upper-triangular entries of each matrix. The naive thing to try would be:

foo[:,np.triu_indices(2)]

However, the result actually has shape (30,2,3,2), as opposed to the (30,3) we might expect had the upper-triangular entries been extracted batch-wise.
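
The unexpected shape can be reproduced with a short check. This is a minimal sketch, assuming a recent NumPy; the names idx, naive and same are illustrative only. It suggests that the tuple from np.triu_indices, when nested inside another index, is treated as a single integer array of shape (2,3) applied to axis 1, rather than as separate row and column indices:

```python
import numpy as np

foo = np.repeat(np.array([[[1, 2], [3, 4]]]), 30, axis=0)  # shape (30, 2, 2)

# np.triu_indices(2) returns a tuple of two index arrays (rows, cols).
idx = np.triu_indices(2)

# Nested inside another index, the tuple is converted to one integer
# array of shape (2, 3) and is applied to axis 1 only.
naive = foo[:, idx]
same = foo[:, np.asarray(idx)]

print(naive.shape)                  # (30, 2, 3, 2)
print(np.array_equal(naive, same))  # True
```

The result shape is the batch axis (30), then the shape of the index array (2,3), then the untouched last axis (2).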

How can we broadcast tuple indexing along the batch dimensions?


Solution

  • Unpack the tuple returned by np.triu_indices into row and column index arrays, and use those to index into the last two dims -

    r,c = np.triu_indices(2)
    out = foo[:,r,c]
    

    Alternatively, a one-liner with Ellipsis that works for both 3D and 2D arrays -

    foo[(Ellipsis,)+np.triu_indices(2)]
    

    The unpacked indices work the same way for 2D arrays -

    out = foo[r,c] # foo as 2D input array
    

    Mask-based approach

    3D array case

    We can also index the last two dims with a boolean mask of the upper triangle -

    foo[:,~np.tri(2,k=-1, dtype=bool)]
    

    2D array case

    foo[~np.tri(2,k=-1, dtype=bool)]
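
As a quick sanity check, the index-based and mask-based approaches above can be compared directly. The following is a minimal sketch, not a definitive implementation; batched_triu is a hypothetical helper name introduced here for illustration, and it assumes the matrices live in the last two axes:

```python
import numpy as np

def batched_triu(a, k=0):
    """Hypothetical helper: upper-triangular entries of the last two axes."""
    n, m = a.shape[-2:]
    r, c = np.triu_indices(n, k=k, m=m)
    # Ellipsis keeps every leading (batch) axis, however many there are.
    return a[..., r, c]

foo = np.repeat(np.array([[[1, 2], [3, 4]]]), 30, axis=0)  # shape (30, 2, 2)

by_index = batched_triu(foo)

# Boolean mask that is True on and above the diagonal.
mask = ~np.tri(2, k=-1, dtype=bool)
by_mask = foo[:, mask]

print(by_index.shape)                       # (30, 3)
print(by_index[0])                          # [1 2 4]
print(np.array_equal(by_index, by_mask))    # True
print(batched_triu(foo[0]))                 # [1 2 4]
print(batched_triu(foo.reshape(5, 6, 2, 2)).shape)  # (5, 6, 3)
```

Both approaches return the entries in row-major order, so the results match. The Ellipsis form also handles zero or several batch axes without changes, whereas foo[:,mask] assumes exactly one.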