Get indices of numpy.argmax elements over an axis


I have an N-dimensional array that contains the values of a function of N parameters. Each parameter takes a discrete number of values. I need to maximize the function over all parameters but one, resulting in a one-dimensional vector whose size equals the number of values of the non-maximized parameter. I also need to record which values the other parameters take at each maximum.

To do so I wanted to iteratively apply numpy.max over different axes to reduce the dimensionality of the matrix to find what I need. The final vector will then depend on just the parameter I left out.

However, I'm having trouble finding the original indices of the final elements (which contain the information about the values taken by the other parameters). I thought about using numpy.argmax in the same way as numpy.max, but I can't recover the original indices from its result.

An example of what I'm trying is:

import numpy as np

x = [[[1,2],[0,1]],[[3,4],[6,7]]]
args = np.argmax(x, 0)

This returns

[[1 1]
 [1 1]]

This means that argmax is selecting the elements (3,4,6,7) of the original array. But how do I get their full indices? I tried unravel_index, using args directly as an index into x, and a number of other numpy indexing functions, with no success.

Using numpy.where is not a solution, since the input array may contain equal values, so I would not be able to tell which original element a given maximum came from.


Solution

  • x.argmax(0) gives the indices along the first axis of the maximum values. Use np.indices to generate the matching indices for the other axes.

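    import numpy as np

    x = np.array([[[1,2],[0,1]],[[3,4],[6,7]]])
    x.argmax(0)
        array([[1, 1],
               [1, 1]])
    # index grids for the two axes that the argmax leaves in place
    a1, a2 = np.indices((2,2))
    (x.argmax(0),a1,a2)
        (array([[1, 1],
                [1, 1]]),
         array([[0, 0],
                [1, 1]]),
         array([[0, 1],
                [0, 1]]))

    # maxima along axis 0: the argmax result goes in position 0
    x[x.argmax(0),a1,a2]
        array([[3, 4],
               [6, 7]])

    # maxima along axis 1: the argmax result goes in position 1
    x[a1,x.argmax(1),a2]
        array([[1, 2],
               [6, 7]])

    # maxima along axis 2: the argmax result goes in position 2
    x[a1,a2,x.argmax(2)]
        array([[2, 1],
               [4, 7]])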

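    If x has a different number of dimensions, generate the index arrays accordingly: calling np.indices on the shape of the argmax result yields one index array per remaining axis.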
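
For an array with any number of dimensions, the same idea could be wrapped in a small helper. This is only a sketch: the function name max_along_axis_with_indices is made up for illustration, and it assumes x is at least one-dimensional. It builds the index grids with np.indices from the shape of the argmax result and inserts the argmax array at the position of the reduced axis.

```python
import numpy as np

def max_along_axis_with_indices(x, axis):
    """Hypothetical helper: maxima of x along `axis` plus their full indices."""
    x = np.asarray(x)
    axis = axis % x.ndim                  # allow negative axis values
    arg = x.argmax(axis)                  # positions along the reduced axis
    # One index grid per axis that argmax leaves in place.
    grids = list(np.indices(arg.shape))
    # Put the argmax result where the reduced axis sits in x.
    grids.insert(axis, arg)
    index = tuple(grids)
    return x[index], index

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
values, index = max_along_axis_with_indices(x, 1)
print(values)   # same result as x[a1, x.argmax(1), a2]
```

Each array in index has the shape of the reduced result, so index[k][i, j] is the coordinate along axis k of the maximum stored at values[i, j].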

    The official documentation does not say much about how to use argmax, but earlier SO threads have discussed it. I got the general idea from the thread "Using numpy.argmax() on multidimensional arrays".
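
For the original goal of maximizing over every axis except one, an alternative to reducing one axis at a time might be to flatten all the other axes, take a single argmax, and convert the flat positions back with np.unravel_index. The sketch below assumes x has at least two dimensions; the name argmax_all_but_one is hypothetical. Because argmax returns the first maximum it finds, repeated values still map to one well-defined index.

```python
import numpy as np

def argmax_all_but_one(x, keep_axis):
    """Hypothetical helper: maximize x over every axis except `keep_axis`.

    Returns (maxima, coords), where coords[i] holds the indices along the
    remaining axes (in their original order) of the maximum for slice i.
    """
    x = np.asarray(x)
    # Bring the kept axis to the front and flatten everything else.
    moved = np.moveaxis(x, keep_axis, 0)
    flat = moved.reshape(moved.shape[0], -1)
    flat_arg = flat.argmax(axis=1)
    # Translate flat positions back into per-axis coordinates.
    other = np.unravel_index(flat_arg, moved.shape[1:])
    maxima = flat[np.arange(flat.shape[0]), flat_arg]
    return maxima, np.stack(other, axis=1)

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
maxima, coords = argmax_all_but_one(x, 0)
print(maxima)   # one maximum per value of the kept parameter
print(coords)   # row i: indices of the other parameters at that maximum
```

Here x[0, 0, 1] and x[1, 1, 1] are the maxima of the two slices along axis 0, so coords holds the pairs (0, 1) and (1, 1).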