
How to build argsecondmax in NumPy


In NumPy, argmax is already defined, but I need an argsecondmax, i.e. the index of the second maximum. How can I do this? I'm a bit confused.


Solution

    Finding Nth largest indices

    An efficient approach is to use np.argpartition, which skips a full sort and simply partitions the array; indexing into the partitioned indices then gives the required index. We can also generalize it to find the Nth largest one, either along a specified axis or globally (similar to ndarray.argmax()), like so -

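    import numpy as np

    def argNmax(a, N, axis=None):
        # argpartition moves the Nth largest element to position -N
        # without fully sorting; its index is what we return.
        if axis is None:
            # Global version: index into the flattened array, like ndarray.argmax()
            return np.argpartition(a.ravel(), -N)[-N]
        else:
            return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)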
    

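To make the indexing at -N concrete, here is a small sketch of the guarantee np.argpartition provides (the array x is made up for illustration): after partitioning around kth=-N, the element at position -N is the one a full sort would put there, everything before it is no larger and everything after it is no smaller. The order within each side is unspecified, which is what lets it avoid the cost of a full sort.

```python
import numpy as np

# Illustrative data (not from the original post); all values distinct
x = np.array([40, 10, 90, 70, 20, 60])
N = 2
k = x.size - N                 # position -N written as a non-negative index

idx = np.argpartition(x, -N)   # indices that partition x around kth = -N
part = x[idx]                  # values in partitioned order

# The value at position k is exactly the Nth largest value
assert part[k] == np.sort(x)[k]
# Everything before it is <= it, everything after it is >= it
assert (part[:k] <= part[k]).all()
assert (part[k + 1:] >= part[k]).all()

print(idx[k], x[idx[k]])       # index 3 holds 70, the second largest
```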
    Sample runs -

    In [66]: a
    Out[66]: 
    array([[908, 770, 258, 534],
           [399, 376, 808, 750],
           [655, 654, 825, 355]])
    
    In [67]: argNmax(a, N=2, axis=0)
    Out[67]: array([2, 2, 1, 0])
    
    In [68]: argNmax(a, N=2, axis=1)
    Out[68]: array([1, 3, 0])
    
    In [69]: argNmax(a, N=2) # global second largest index
    Out[69]: 10
    

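Since the question asks specifically for an argsecondmax, here is a minimal sketch of a thin wrapper around argNmax with N=2 (the name argsecondmax is hypothetical, not a NumPy function). The global version returns an index into the flattened array, just like ndarray.argmax(); np.unravel_index converts it back to a (row, column) position.

```python
import numpy as np

def argNmax(a, N, axis=None):
    # Same function as above: argpartition puts the Nth largest at position -N
    if axis is None:
        return np.argpartition(a.ravel(), -N)[-N]
    return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)

def argsecondmax(a, axis=None):
    # Hypothetical convenience wrapper: index of the second largest element
    return argNmax(a, 2, axis=axis)

a = np.array([[908, 770, 258, 534],
              [399, 376, 808, 750],
              [655, 654, 825, 355]])

flat_idx = argsecondmax(a)                      # index into a.ravel()
row, col = np.unravel_index(flat_idx, a.shape)  # back to 2D coordinates
print(flat_idx, row, col, a[row, col])          # 10 2 2 825
print(argsecondmax(a, axis=1))                  # [1 3 0]
```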
    Finding Nth smallest indices

    Extending this to find the Nth smallest one along an axis or globally, we would have -

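    def argNmin(a, N, axis=None):
        # argpartition moves the Nth smallest element to position N-1
        # (0-based), so we return the index stored there.
        if axis is None:
            return np.argpartition(a.ravel(), N-1)[N-1]
        else:
            return np.take(np.argpartition(a, N-1, axis=axis), N-1, axis=axis)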
    

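A quick way to convince yourself that argNmin (and argNmax) return the right indices is to compare them with a full np.argsort on data without ties, where the answer is unique. This is a sketch under that assumption: it uses a random permutation so all values are distinct, and the seed and shape are arbitrary choices.

```python
import numpy as np

def argNmax(a, N, axis=None):
    if axis is None:
        return np.argpartition(a.ravel(), -N)[-N]
    return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)

def argNmin(a, N, axis=None):
    if axis is None:
        return np.argpartition(a.ravel(), N - 1)[N - 1]
    return np.take(np.argpartition(a, N - 1, axis=axis), N - 1, axis=axis)

rng = np.random.default_rng(0)
a = rng.permutation(60).reshape(5, 12)   # distinct values, so answers are unique

for N in (1, 2, 3):
    # Global versions vs. argsort over the flattened array
    assert argNmax(a, N) == np.argsort(a, axis=None)[-N]
    assert argNmin(a, N) == np.argsort(a, axis=None)[N - 1]
    for axis in (0, 1):
        full = np.argsort(a, axis=axis)  # full sort along the axis
        assert (argNmax(a, N, axis) == np.take(full, -N, axis=axis)).all()
        assert (argNmin(a, N, axis) == np.take(full, N - 1, axis=axis)).all()

print("argNmax/argNmin agree with argsort")
```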
    Sample runs -

    In [105]: a
    Out[105]: 
    array([[908, 770, 258, 534],
           [399, 376, 808, 750],
           [655, 654, 825, 355]])
    
    In [106]: argNmin(a, N=2, axis=0)
    Out[106]: array([2, 2, 1, 0])
    
    In [107]: argNmin(a, N=2, axis=1)
    Out[107]: array([3, 0, 1])
    
    In [108]: argNmin(a, N=2)
    Out[108]: 11
    

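One caveat: both functions work on positions in sorted order, so with repeated values the "second largest" may simply be another copy of the maximum. If what you want is the index of the second *distinct* value, here is a hedged sketch of one alternative (the helper name arg_second_distinct_max is hypothetical): exclude every occurrence of the maximum, then take the argmax of what remains.

```python
import numpy as np

def arg_second_distinct_max(a):
    """Flat index of the largest value strictly smaller than a.max().

    Hypothetical helper; ties among the remaining values resolve to the
    first occurrence, as with np.argmax.
    """
    flat = np.asarray(a).ravel()
    candidates = np.flatnonzero(flat < flat.max())  # positions not holding the max
    if candidates.size == 0:
        raise ValueError("all elements are equal; no second distinct value")
    return candidates[np.argmax(flat[candidates])]

x = np.array([3, 7, 7, 1])

# argpartition-based: sorted position -2 holds another 7
second = np.argpartition(x, -2)[-2]
print(x[second])                     # 7
print(arg_second_distinct_max(x))    # 0, the index of value 3
```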
    Timings

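    To put the benefit of using argpartition over a full sort with argsort (as in @pythonic833's post) into perspective, here's a quick runtime test of the global argNmax version. Note that on a 2D array, np.argsort(a)[-2] as timed below argsorts each row separately; the exact global counterpart would be np.argsort(a, axis=None)[-2] -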

    In [70]: a = np.random.randint(0,99999,(1000,1000))
    
    In [72]: %timeit np.argsort(a)[-2] # @pythonic833's soln
    10 loops, best of 3: 40.6 ms per loop
    
    In [73]: %timeit argNmax(a, N=2)
    100 loops, best of 3: 2.12 ms per loop
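Outside IPython, a comparable measurement can be sketched with the standard-library timeit module. This sketch assumes the same array size and value range as above (generated with np.random.default_rng rather than np.random.randint), uses np.argsort(a, axis=None) so both sides compute the same global quantity, and quotes no numbers, since absolute timings depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def argNmax(a, N, axis=None):
    if axis is None:
        return np.argpartition(a.ravel(), -N)[-N]
    return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)

rng = np.random.default_rng(0)
a = rng.integers(0, 99999, size=(1000, 1000))

# With duplicates the two indices may differ, but both must point
# at an element holding the same (second largest) value
assert a.ravel()[np.argsort(a, axis=None)[-2]] == a.ravel()[argNmax(a, N=2)]

runs = 5
t_sort = min(timeit.repeat(lambda: np.argsort(a, axis=None)[-2], number=runs, repeat=3)) / runs
t_part = min(timeit.repeat(lambda: argNmax(a, N=2), number=runs, repeat=3)) / runs
print(f"argsort:      {t_sort * 1e3:.2f} ms per call")
print(f"argpartition: {t_part * 1e3:.2f} ms per call")
```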