I have a 3-D array A of shape (n, m, s). For each index along the 0th axis, I need the position of the maximum within the corresponding (m, s)-shaped array.
For example:
import numpy as np
np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
Here A[0] is:
array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])
I want to obtain (0, 2), i.e. the position of 9 here.
I would love to do aa = A.argmax(), such that aa.shape is (10, 2) and aa[0] is [0, 2].
How can I achieve this?
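For context, a quick check of what a plain A.argmax() returns. With no axis argument, NumPy flattens the entire (n, m, s) array and returns a single flat index, so unraveling it only locates the one global maximum. The sketch below assumes the same random setup as above; the names flat and pos are illustrative only.

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

# With no axis, argmax flattens the whole (10, 3, 3) array and
# returns a single integer index, not one result per slice.
flat = A.argmax()
print(np.ndim(flat))  # 0, i.e. a scalar

# Unraveling that index against A.shape gives a 3-tuple (i, j, k):
# the location of the global maximum, not an (n, 2) array.
pos = np.unravel_index(flat, A.shape)
print(len(pos))  # 3
```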
Use np.unravel_index with a list comprehension:
out = [np.unravel_index(np.argmax(block, axis=None), block.shape) for block in A]
where block is the 3x3 (m x s) array at each iteration.
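For a single block, the two steps inside the comprehension work like this: np.argmax with axis=None returns an index into the flattened (row-major) block, and np.unravel_index converts it back to a (row, column) pair. This sketch hard-codes the A[0] values shown in the question so it does not depend on the random seed.

```python
import numpy as np

# The A[0] block from the question, written out explicitly
block = np.array([[5, 8, 9],
                  [5, 0, 0],
                  [1, 7, 6]])

# argmax over the flattened block: 9 is at flat position 2
flat_idx = np.argmax(block, axis=None)
# convert the flat index back into (row, column) for a 3x3 shape
pos = np.unravel_index(flat_idx, block.shape)
print(flat_idx, tuple(int(i) for i in pos))  # 2 (0, 2)
```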
The comprehension gives a list with 10 (n) entries:
[(0, 2), (0, 0), (0, 1), (0, 1), (2, 0), (1, 2), (0, 0), (1, 1), (1, 0), (1, 2)]
You can convert this to a NumPy array of the desired shape (n, 2):
aa = np.array(out)
to get:
array([[0, 2],
       [0, 0],
       [0, 1],
       [0, 1],
       [2, 0],
       [1, 2],
       [0, 0],
       [1, 1],
       [1, 0],
       [1, 2]], dtype=int64)
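If you want to avoid the Python-level loop, one possible alternative (a sketch, not part of the approach above) is to reshape A to (n, m*s), take argmax along axis 1 to get one flat index per block, and pass the whole array of indices to np.unravel_index, which accepts arrays. Stacking the resulting row and column arrays as columns gives the (n, 2) result directly.

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
n, m, s = A.shape

# Flatten each (m, s) block into a row of length m*s, then take the
# argmax along that row: one flat index per block, shape (n,).
flat_idx = A.reshape(n, m * s).argmax(axis=1)

# unravel_index accepts an array of flat indices and returns a tuple
# (rows, cols) of arrays; stack them as columns to get shape (n, 2).
aa = np.stack(np.unravel_index(flat_idx, (m, s)), axis=1)

# Cross-check against the list-comprehension version
out = [np.unravel_index(np.argmax(block, axis=None), block.shape) for block in A]
assert np.array_equal(aa, np.array(out))
print(aa.shape)  # (10, 2)
```

Both versions break ties the same way: argmax returns the first occurrence in row-major order within each block.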