Tags: python, numpy, matrix-indexing

How to index the elements of each matrix row with a matrix of per-row indices?


I have a matrix of indices I, e.g.:

import numpy as np
I = np.array([[1, 0, 2], [2, 1, 0]])

Each index in the i-th row of I selects an element from the i-th row of another matrix M.

So, given M, e.g.:

M = np.array([[6, 7, 8], [9, 10, 11]])

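something like M[I] should select the following (plain M[I] does not do this, because NumPy applies I to the rows of M):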

[[7, 6, 8], [11, 10, 9]]
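Spelled out with explicit loops, the requested operation is result[i, j] = M[i, I[i, j]]. The sketch below is slow, but it pins down the semantics and is handy as a reference when checking faster versions; the function name `select_per_row_loop` is made up for illustration.

```python
import numpy as np

def select_per_row_loop(M, I):
    """Reference version (hypothetical helper): result[i, j] = M[i, I[i, j]]."""
    result = np.empty(I.shape, dtype=M.dtype)
    for i in range(I.shape[0]):          # each row of I ...
        for j in range(I.shape[1]):      # ... picks columns from the same row of M
            result[i, j] = M[i, I[i, j]]
    return result

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])
print(select_per_row_loop(M, I))
# [[ 7  6  8]
#  [11 10  9]]
```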

One way that works is:

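I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])  # row number for every entry of I
I2 = np.ravel(I)                                      # column indices, flattened row by row
Result = M[I1, I2].reshape(I.shape)                   # gather, then restore the shape of I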

but this looks overly complicated, and I am looking for a more elegant solution, preferably one without flattening and reshaping.
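For the example arrays, the intermediate vectors in that workaround look like this: I1 repeats each row number once per column of I, and I2 lists the column indices row by row, so each pair (I1[k], I2[k]) names one element of M.

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])  # row numbers:    [0 0 0 1 1 1]
I2 = np.ravel(I)                                      # column indices: [1 0 2 2 1 0]
flat = M[I1, I2]                                      # gathered:       [ 7  6  8 11 10  9]
Result = flat.reshape(I.shape)                        # back to the (2, 3) shape of I

print(I1, I2, flat)
print(Result)
```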

In the example I used NumPy, but I am actually using JAX, so if there is a more efficient solution in JAX, feel free to share.


Solution

    I had to add a missing closing ']' to the definition of M:

    In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
         ...: M = np.array([[6, 7, 8], [9, 10, 11]])
         ...: 
         ...: I,M
    Out[108]: 
    (array([[1, 0, 2],
            [2, 1, 0]]),
     array([[ 6,  7,  8],
            [ 9, 10, 11]]))

    Advanced indexing with broadcasting:

    In [110]: M[np.arange(2)[:,None],I]
    Out[110]: 
    array([[ 7,  6,  8],
           [11, 10,  9]])
    

    The first index, np.arange(2)[:,None], has shape (2,1). It broadcasts against the (2,3) shape of I, so together they select a (2,3) block of values in which element [i, j] is M[i, I[i, j]].
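Wrapped as a function that works for any number of rows, the same broadcasting idea might look like the sketch below (`select_per_row` is a hypothetical name). NumPy's `np.take_along_axis(M, I, axis=1)` performs the same per-row gather and is used here only as a cross-check. `jax.numpy` also supports integer-array indexing and provides `take_along_axis`, so replacing `np` with `jnp` should carry over to JAX; that is an assumption, and neither variant is timed here.

```python
import numpy as np

def select_per_row(M, I):
    """Return R with R[i, j] = M[i, I[i, j]] (hypothetical helper name)."""
    rows = np.arange(I.shape[0])[:, None]  # shape (n, 1): one row number per row of I
    return M[rows, I]                      # (n, 1) broadcasts with (n, k) -> (n, k)

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

R = select_per_row(M, I)
print(R)
# [[ 7  6  8]
#  [11 10  9]]

# np.take_along_axis does the same per-row gather along axis 1.
assert np.array_equal(R, np.take_along_axis(M, I, axis=1))

# I may have a different number of columns than M, and may repeat indices.
J = np.array([[2, 2], [0, 1]])
print(select_per_row(M, J))
# [[ 8  8]
#  [ 9 10]]
```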