
Multi-indexing with tuples


I have a multi-dimensional np.array. I know the shape of the first N dimensions and of the last M dimensions. E.g.,

>>> n = (3,4,5)
>>> m = (6,)
>>> a = np.ones(n + m)
>>> a.shape
(3, 4, 5, 6)

Using a tuple as the index allows for quick indexing of the first N dimensions, like

>>> i = (1,1,2)
>>> a[i].shape
(6,)

Using a list does not give me the result I need:

>>> i = [1,1,2]
>>> a[i].shape
(3, 4, 5, 6)

But I am having trouble doing multi-indexing (both to retrieve and to assign values). For example,

>>> i = (1,1,2)
>>> j = (2,2,2)

I need to pass something like

>>> a[[i, j]]

and get an output shape of (2, 6).

Instead I get

>>> a[[i, j]].shape
(2, 3, 4, 5, 6)

or

>>> a[(i, j)].shape
(3, 5, 6)

I can always loop or change how I index things (like using np.reshape and np.unravel_index), but is there a more pythonic way to achieve what I need?

EDIT: I need this to work for any number of indices, e.g.,

>>> i = (1,1,2)
>>> j = (2,2,2)
>>> k = (0,0,0)
...

Solution

  • Consider a list of indices:

    idx = [
        (1, 1, 2),  # Your i
        (2, 2, 2),  # Your j
        (0, 0, 0),  # Your k
        (1, 2, 1),  # ... 
        (2, 0, 1),  # extend as necessary
    ]
    

    and your array a with shape (3, 4, 5, 6).

    When you write out = a[idx], numpy interprets it like this:

    out = np.array([
        [a[1], a[1], a[2]],
        [a[2], a[2], a[2]],
        [a[0], a[0], a[0]],
        [a[1], a[2], a[1]],
        [a[2], a[0], a[1]],
    ])
    

    Here, e.g., a[0] is just the first subarray of a along axis 0, so it has shape (4, 5, 6).

    As a result, you get an array of shape (5, 3) (the shape of the index) whose elements are (4, 5, 6) subarrays of a. The final shape is therefore (5, 3, 4, 5, 6), i.e. np.shape(idx) + a.shape[1:].
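A quick check of that shape rule, as a sketch rather than anything definitive. It converts idx to an array explicitly, because older NumPy releases treated a bare list of sequences as a tuple index and warned about it; the arange-based contents of a are only an assumption, chosen so the selected subarrays are distinguishable.

```python
import numpy as np

# Same shape as in the question; the values are arbitrary but distinct.
a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
idx = [(1, 1, 2), (2, 2, 2), (0, 0, 0), (1, 2, 1), (2, 0, 1)]

# As a single integer array, every entry of idx_arr selects along the
# FIRST axis of a only.
idx_arr = np.array(idx)   # shape (5, 3)
out = a[idx_arr]

print(out.shape)          # (5, 3, 4, 5, 6)
```

Each out[r, c] is the whole subarray a[idx[r][c]], which is why the three numbers in a tuple never get combined into one lookup.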


    Instead what you want is the following:

    out = np.array([
        a[1, 1, 2],
        a[2, 2, 2],
        a[0, 0, 0],
        a[1, 2, 1],
        a[2, 0, 1],
    ])
    

    The vectorized way to do this in numpy is to pass one index list per axis:

    out = a[
        [1, 2, 0, 1, 2],  # [idx[0][0], idx[1][0], idx[2][0], ...]
        [1, 2, 0, 2, 0],  # [idx[0][1], idx[1][1], idx[2][1], ...]
        [2, 2, 0, 1, 1]   # [idx[0][2], idx[1][2], idx[2][2], ...]
    ]
    

    That behavior is documented in the indexing guide:

    Advanced indices always are broadcast and iterated as one:

    result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                               ..., ind_N[i_1, ..., i_M]]
    

    To transform the original idx into such an indexer, you can use the tuple(zip(*idx)) trick, which regroups the index tuples into one sequence per axis: out = a[tuple(zip(*idx))].

    Numpy's indexing system is magically flexible, but the cost of that flexibility is that these kinds of "simple" tasks become unintuitive... at least in my opinion ;)
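Putting the answer together, here is a minimal sketch of the zip-based indexer used for both retrieval and assignment, since the question asks for both. The names indexer and b and the fill value -1 are illustrative choices, not part of the original.

```python
import numpy as np

a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
idx = [(1, 1, 2), (2, 2, 2), (0, 0, 0), (1, 2, 1), (2, 0, 1)]

# Regroup the index tuples into one sequence per leading axis:
# ((1, 2, 0, 1, 2), (1, 2, 0, 2, 0), (2, 2, 0, 1, 1))
indexer = tuple(zip(*idx))

# Retrieval: one row of length 6 per index tuple.
out = a[indexer]
print(out.shape)  # (5, 6)

# Assignment accepts the same indexer; the scalar is broadcast over
# the remaining last axis of each selected position.
b = a.copy()
b[indexer] = -1
```

If idx is already a NumPy array of shape (K, 3), tuple(idx.T) should give an equivalent indexer without going through zip.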