
How do I get indices of N maximum values in a NumPy array?


NumPy provides a way to get the index of the maximum value of an array via np.argmax.

I would like something similar, but returning the indices of the N maximum values.

For instance, if I have an array, [1, 3, 2, 4, 5], then nargmax(array, n=3) would return the indices [4, 3, 1] which correspond to the elements [5, 4, 3].
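For reference, the requested behaviour can be pinned down with a plain full sort. The sketch below is a hypothetical helper (`nargmax_sorted` is not a NumPy function) that reproduces the example above. It costs O(n log n) because it sorts the whole array. Reversing before slicing, rather than slicing with `[-n:]`, keeps `n=0` from returning every index.

```python
import numpy as np

def nargmax_sorted(array, n):
    """Hypothetical helper: indices of the n largest values, largest first.

    Baseline using a full argsort, so it costs O(len(array) log len(array)).
    """
    array = np.asarray(array)
    # argsort orders indices by ascending value; reverse, then keep the first n
    return np.argsort(array)[::-1][:n]

print(nargmax_sorted([1, 3, 2, 4, 5], n=3))  # [4 3 1]
```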


Solution

  • NumPy 1.8 and later provide a function called np.argpartition for this. To get the indices of the four largest elements, do:

    >>> import numpy as np
    >>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
    >>> a
    array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
    
    >>> ind = np.argpartition(a, -4)[-4:]
    >>> ind
    array([1, 5, 8, 0])
    
    >>> top4 = a[ind]
    >>> top4
    array([4, 9, 6, 9])
    

    Unlike argsort, which performs a full O(n log n) sort, this function runs in linear time in the worst case. However, the returned indices are not sorted, as the result of evaluating a[ind] shows. If you need them in order, sort them afterwards (this gives ascending order of value):

    >>> ind[np.argsort(a[ind])]
    array([1, 8, 5, 0])
    

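    Getting the top-k elements in sorted order this way takes O(n + k log k) time: O(n) for the partition plus O(k log k) to sort the k selected indices. For descending order, as in the question, reverse the result with [::-1].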
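The two steps above (partition, then sort only the selected indices) can be wrapped into the `nargmax(array, n)` function the question asks for. This is a minimal sketch: the name `nargmax` is hypothetical and not part of NumPy, and it assumes `0 <= n <= len(array)`; `np.argpartition` raises a `ValueError` if `n` exceeds the array length. It returns indices largest-first, matching the question's example.

```python
import numpy as np

def nargmax(array, n):
    """Hypothetical helper: indices of the n largest values, largest first.

    argpartition selects the top n in O(len(array)) time; only those n
    indices are then sorted, for O(len(array) + n log n) overall.
    Assumes 0 <= n <= len(array).
    """
    array = np.asarray(array)
    if n <= 0:
        return np.empty(0, dtype=np.intp)
    # Partition so the n largest values occupy the last n slots (in no particular order).
    top = np.argpartition(array, -n)[-n:]
    # Sort just those n indices by value (ascending), then reverse for largest first.
    return top[np.argsort(array[top])[::-1]]

print(nargmax([1, 3, 2, 4, 5], n=3))  # [4 3 1]
```

With tied values (such as the two 9s in the array above) the relative order of the tied indices is not guaranteed, but the corresponding values always come out in descending order.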