Using np.argpartition to index values in a multidimensional array


I have an array like this one:

>>> a = np.arange(60).reshape([3,4,5])
>>> a
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

I want to retrieve the top k values along one of the dimensions. For example's sake, I'll choose k=2 along the middle dimension (axis=1).

I've tried using np.argpartition and it seems to do the right thing, but I'm having trouble using its output to retrieve the values from the original array. Here's how I'm using argpartition:

>>> indices = np.argpartition(a, 2, axis=1)
>>> indices
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

>>> indices[:,-2:,:]
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

But I haven't been able to get the values out by indexing the original array with these indices:

>>> a[:,indices[:,-2:,:],:].shape
(3, 3, 2, 5, 5)

I am expecting to see an array of shape (3,2,5) (as I'm looking for the top-2 along the middle axis) that I imagine looks something like this:

>>> magic_output
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

How can I access the values using the indices from argpartition?


Solution

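  • np.argpartition partitions in ascending order: with kth=k, the first k positions along the axis hold the indices of the k smallest values, in no guaranteed order. So, to get the indices of the top k values, we can run it on the negated input array along the desired axis and keep the first k positions. Then we use those indices to index into that axis with NumPy's advanced indexing, pairing them with broadcastable index arrays for the other two axes, to get the desired output.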

    Thus, the implementation would be -

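    k = 2
    m, n = a.shape[0], a.shape[2]
    # Indices of the k largest values along axis 1 (order not guaranteed; kth=k needs k < a.shape[1])
    idx = np.argpartition(-a, k, axis=1)[:, k-1::-1]
    # Pair idx with broadcast row/column indices so that out has shape (m, k, n)
    out = a[np.arange(m)[:, None, None], idx, np.arange(n)]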
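
To see why the indexing expression above returns an array of shape (m, k, n), it may help to look at the shapes of the three index arrays. The following is only an illustrative sketch that reuses the names from the code above; `rows` and `cols` are names introduced here for readability and are not part of the original answer.

```python
import numpy as np

a = np.arange(60).reshape(3, 4, 5)
k = 2
m, n = a.shape[0], a.shape[2]

# Indices of the k largest values along axis 1 (order among them is not guaranteed)
idx = np.argpartition(-a, k, axis=1)[:, k-1::-1]   # shape (m, k, n)

rows = np.arange(m)[:, None, None]                 # shape (m, 1, 1)
cols = np.arange(n)                                # shape (n,)

# The three index arrays broadcast against each other to a common shape
print(np.broadcast(rows, idx, cols).shape)         # (3, 2, 5)

# out[i, j, l] == a[i, idx[i, j, l], l]
out = a[rows, idx, cols]
print(out.shape)                                   # (3, 2, 5)
```

By contrast, `a[:, idx, :]` uses only one index array, so the full shape of `idx` replaces axis 1, which gives the (3, 3, 2, 5, 5) result seen in the question.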
    

    Sample run -

    1) Input :

    In [180]: a
    Out[180]: 
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19]],
    
           [[20, 21, 22, 23, 24],
            [25, 26, 27, 28, 29],
            [30, 31, 32, 33, 34],
            [35, 36, 37, 38, 39]],
    
           [[40, 41, 42, 43, 44],
            [45, 46, 47, 48, 49],
            [50, 51, 52, 53, 54],
            [55, 56, 57, 58, 59]]])
    

    2) Proposed code :

    In [206]: k = 2
         ...: m,n = a.shape[0], a.shape[2]
         ...: idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
         ...: out = a[np.arange(m)[:,None,None], idx, np.arange(n)]
         ...: 
    

    3) Check the intermediate results and output :

    In [207]: idx
    Out[207]: 
    array([[[2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3]],
    
           [[2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3]],
    
           [[2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3]]])
    
    In [208]: out
    Out[208]: 
    array([[[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19]],
    
           [[30, 31, 32, 33, 34],
            [35, 36, 37, 38, 39]],
    
           [[50, 51, 52, 53, 54],
            [55, 56, 57, 58, 59]]])
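
On NumPy 1.15 or newer, `np.take_along_axis` can perform the same gather without building the row and column index arrays by hand. The sketch below is an alternative to the code above, not part of the original answer, and the helper name `top_k_along_axis` is hypothetical. Since `np.argpartition` does not guarantee any order among the k selected entries, this version sorts them afterwards so the result is deterministic. It assumes 1 <= k <= a.shape[axis].

```python
import numpy as np

def top_k_along_axis(a, k, axis):
    """Return the k largest values of `a` along `axis`, in ascending order.

    Hypothetical helper for illustration; assumes 1 <= k <= a.shape[axis].
    """
    size = a.shape[axis]
    # kth=-k moves the k largest entries into the last k positions along `axis`
    part = np.argpartition(a, -k, axis=axis)
    # Keep only the index entries for those last k positions
    idx = np.take(part, np.arange(size - k, size), axis=axis)
    # Gather the matching values from `a`, then put them in ascending order
    vals = np.take_along_axis(a, idx, axis=axis)
    return np.sort(vals, axis=axis)

a = np.arange(60).reshape(3, 4, 5)
out = top_k_along_axis(a, 2, axis=1)
print(out.shape)   # (3, 2, 5)
print(out[0])      # rows [10 .. 14] and [15 .. 19]
```

Using `kth=-k` on the original array avoids negating the input, which may matter for unsigned integer dtypes where negation wraps around.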