
np.argmax that returns tuple


I have an array A of shape (n, m, s). For each index i along axis 0, I need the (row, column) position of the maximum within the (m, s)-shaped subarray A[i].

For example:

import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

A[0] is:

array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])

I want to obtain (0, 2), i.e. the position of 9 here.

Ideally I would like to do something like aa = A.argmax(), such that aa.shape == (10, 2) and aa[0] == [0, 2].

How can I achieve this?


Solution
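The approach below combines two NumPy building blocks. With axis=None, np.argmax treats its input as flattened and returns a single integer index, and np.unravel_index maps such a flat index back to coordinates for a given shape. A minimal sketch on the A[0] block from the question (hard-coded here so the snippet runs on its own):

```python
import numpy as np

# The A[0] block from the question, hard-coded for a self-contained example.
block = np.array([[5, 8, 9],
                  [5, 0, 0],
                  [1, 7, 6]])

# axis=None: argmax scans the flattened block and returns one integer,
# an index into block.ravel() == [5, 8, 9, 5, 0, 0, 1, 7, 6].
flat_idx = np.argmax(block, axis=None)
print(flat_idx)  # 2

# unravel_index converts the flat index into (row, column) for shape (3, 3),
# using row-major (C) order by default.
row, col = np.unravel_index(flat_idx, block.shape)
print(row, col)  # 0 2
```

This is also why a bare A.argmax() cannot give the (10, 2) result directly: without an axis it flattens all of A and returns a single index for the whole array.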

  • Using np.unravel_index with a list comprehension:

    out = [np.unravel_index(np.argmax(block, axis=None), block.shape) for block in A]
    

    where block is each 3x3 (m x s) subarray A[i] in turn.

    This gives a list with 10 (n) entries:

    [(0, 2), (0, 0), (0, 1), (0, 1), (2, 0), (1, 2), (0, 0), (1, 1), (1, 0), (1, 2)]
    

    You can convert this to a NumPy array of the desired shape (n, 2):

    aa = np.array(out)
    

    to get:

    array([[0, 2],
           [0, 0],
           [0, 1],
           [0, 1],
           [2, 0],
           [1, 2],
           [0, 0],
           [1, 1],
           [1, 0],
           [1, 2]], dtype=int64)
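An alternative that is not part of the answer above, sketched here under the assumption that a loop-free version is wanted: reshape A to (n, m*s), take argmax along axis 1 to get one flat index per block, then pass the whole index array to np.unravel_index, which accepts arrays as well as scalars.

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

n, m, s = A.shape

# Flatten each (m, s) block into a row of length m*s and take the argmax
# of every row at once: one flat index per block, shape (n,).
flat_idx = A.reshape(n, m * s).argmax(axis=1)

# unravel_index works element-wise on an index array and returns a tuple
# (rows, cols), each of shape (n,).
rows, cols = np.unravel_index(flat_idx, (m, s))

# Put the two coordinate arrays side by side as columns: shape (n, 2).
aa = np.column_stack((rows, cols))
print(aa.shape)  # (10, 2)
print(aa[0])     # [0 2]
```

Both versions resolve ties the same way: argmax returns the first maximum in row-major order. The vectorized form simply avoids the Python-level loop over the n blocks.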