Tags: python, numpy, argmax

Calculating the argmax of one array and using it to get values from another


I am trying to compute the argmax of one ndarray and use it to select values from another ndarray, but I am doing something wrong.

import numpy as np

ndvi_array = np.random.randint(0, 255, size=(4, 1, 100, 100))
image_array = np.random.randint(0, 255, size=(4, 12, 100, 100))
ndvi_argmax = ndvi_array.argmax(0)  # index of the maximum along axis 0
print(f"NDVI argmax shape: {ndvi_argmax.shape}")  # (1, 100, 100)
zipped = tuple(zip(range(len(ndvi_argmax)), ndvi_argmax))
result = image_array[zipped]  # this line raises the error
print(f"Result shape: {result.shape}")

I get the following error:

only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

How can I get an array of shape (1, 12, 100, 100) holding, for every pixel, the image_array values at the axis-0 position where ndvi_array is largest?
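To pin down what is being asked, here is a slow but explicit sketch: for every pixel (i, j) it finds the axis-0 index of the largest NDVI value and copies all 12 bands of image_array from that index. The seeded generator and the helper name select_by_argmax_loop are illustrative assumptions, not part of the question; the loop is only meant as a reference for checking a vectorized answer.

```python
import numpy as np

# Inputs with the same shapes as in the question; the seed is only for reproducibility.
rng = np.random.default_rng(0)
ndvi_array = rng.integers(0, 255, size=(4, 1, 100, 100))
image_array = rng.integers(0, 255, size=(4, 12, 100, 100))


def select_by_argmax_loop(ndvi, image):
    """Hypothetical reference helper: for each pixel, take every band of
    `image` at the axis-0 index where `ndvi` is largest."""
    _, n_bands, height, width = image.shape
    out = np.empty((1, n_bands, height, width), dtype=image.dtype)
    for i in range(height):
        for j in range(width):
            k = ndvi[:, 0, i, j].argmax()  # first index of the max along axis 0
            out[0, :, i, j] = image[k, :, i, j]
    return out


expected = select_by_argmax_loop(ndvi_array, image_array)
print(expected.shape)  # (1, 12, 100, 100)
```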


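Solution

Use np.take_along_axis, and pass keepdims=True to argmax so that the index array has the same number of dimensions (4) as image_array; take_along_axis broadcasts it over the other axes: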

>>> result = np.take_along_axis(
...     image_array,
...     ndvi_array.argmax(axis=0, keepdims=True),
...     axis=0,
... )
>>> print(f"{result.shape = }")
result.shape = (1, 12, 100, 100)
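The keepdims argument of argmax needs NumPy 1.22 or newer, and the f"{expr = }" form needs Python 3.8 or newer. Below is a sketch for older setups, assuming the shapes from the question: it restores the reduced axis by hand with np.newaxis, and also spells the same lookup out with advanced (fancy) indexing, using one broadcastable index array per axis. The variable names are illustrative, not from the original answer.

```python
import numpy as np

rng = np.random.default_rng(0)
ndvi_array = rng.integers(0, 255, size=(4, 1, 100, 100))
image_array = rng.integers(0, 255, size=(4, 12, 100, 100))

# argmax without keepdims gives shape (1, 100, 100); re-insert an axis -> (1, 1, 100, 100)
best = ndvi_array.argmax(axis=0)[:, np.newaxis]

# Same call as the answer, but without relying on keepdims.
result_take = np.take_along_axis(image_array, best, axis=0)

# Equivalent advanced indexing: index arrays for the band, row and column axes,
# shaped so that all four broadcast to the output shape (1, 12, 100, 100).
n_bands, height, width = image_array.shape[1:]
bands = np.arange(n_bands)[np.newaxis, :, np.newaxis, np.newaxis]
rows = np.arange(height)[np.newaxis, np.newaxis, :, np.newaxis]
cols = np.arange(width)[np.newaxis, np.newaxis, np.newaxis, :]
result_fancy = image_array[best, bands, rows, cols]

print(result_fancy.shape)  # (1, 12, 100, 100)
print(np.array_equal(result_fancy, result_take))  # True
```

Here result_fancy[0, c, i, j] equals image_array[best[0, 0, i, j], c, i, j], which is exactly the element take_along_axis selects along axis 0.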