
How to find all argmax indices in an ndarray


I have a 2-dimensional NumPy ndarray:

array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])

How can I get the indices of all maximum elements? For this example, I would like the output array([[0, 1], [2, 2]]).


Solution

  • Use np.argwhere on the max-equality mask -

    np.argwhere(a == a.max())
    

    Sample run -

    In [552]: a   # Input array
    Out[552]: 
    array([[  0.,  20.,  -2.],
           [  2.,   1.,   0.],
           [  4.,   3.,  20.]])
    
    In [553]: a == a.max() # Max equality mask
    Out[553]: 
    array([[False,  True, False],
           [False, False, False],
           [False, False,  True]], dtype=bool)
    
    In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
    Out[554]: 
    array([[0, 1],
           [2, 2]])
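The same steps as a standalone script, for anyone who wants to run them outside IPython. It also shows np.nonzero, a related NumPy function not used in the answer above, which returns the same positions as a tuple of per-axis index arrays (handy for fancy indexing), and a small 3-D array made up purely for illustration to show that np.argwhere yields one row of a.ndim indices per match in any dimension:

```python
import numpy as np

# The 2-D array from the question
a = np.array([[0., 20., -2.],
              [2., 1., 0.],
              [4., 3., 20.]])

# Boolean mask: True wherever an element equals the global maximum
mask = a == a.max()

# np.argwhere returns an (N, a.ndim) array: one row of indices per match
idx = np.argwhere(mask)
print(idx)

# np.nonzero returns the same positions as a tuple of per-axis index arrays,
# which can be used directly to pull the matching values back out
rows, cols = np.nonzero(mask)
print(a[rows, cols])

# Hypothetical 3-D example: each row of the result now has 3 indices
c = np.zeros((2, 2, 2))
c[0, 1, 1] = c[1, 0, 0] = 5.0
print(np.argwhere(c == c.max()))
```

Both functions list matches in row-major (C) order, so the rows of np.argwhere and the entries of np.nonzero line up with each other.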
    

    If you are working with floating-point numbers, you might want to allow some tolerance in the comparison. For that, you could use np.isclose, which applies default relative and absolute tolerances (rtol=1e-05, atol=1e-08). It would replace the earlier a == a.max() part, like so -

    In [555]: np.isclose(a, a.max())
    Out[555]: 
    array([[False,  True, False],
           [False, False, False],
           [False, False,  True]], dtype=bool)
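A small sketch of why the tolerance can matter, using a made-up array (not from the question) in which one entry is computed as 0.1 + 0.2. That sum is stored as 0.30000000000000004, so exact equality with the maximum misses the entries written as the literal 0.3, while np.isclose with its default tolerances treats all three as maxima:

```python
import numpy as np

# Hypothetical example: 0.1 + 0.2 is stored slightly above the literal 0.3
b = np.array([[0.1 + 0.2, 0.3],
              [0.0,       0.3]])

# Exact equality only matches the single largest stored float
exact = np.argwhere(b == b.max())

# np.isclose with default rtol/atol treats all three "0.3" entries as maxima
close = np.argwhere(np.isclose(b, b.max()))

print(exact)
print(close)
```

If the defaults are too loose or too strict for your data, rtol and atol can be passed to np.isclose explicitly.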