
Python NumPy: from argmax indices to max values in a multidimensional array


I have the following code:

    import numpy as np
    sample = np.random.random((10,10,3))
    argmax_indices = np.argmax(sample, axis=2)

That is, I take the argmax along axis=2, which gives me a (10,10) array of indices. Now I want to use these indices to read the maximum values out of sample and then set those positions to 0. I tried:

    max_values = sample[argmax_indices]

but it doesn't work (the result is not the (10,10) array of maxima). I want something like:

    max_values = sample[argmax_indices]
    sample[argmax_indices] = 0

To validate, I check that max_values - np.max(sample, axis=2), computed before the zeroing step, gives a zero matrix of shape (10,10). Any help will be appreciated.
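A quick shape check shows why the first attempt misbehaves. This is a sketch using the same random setup as the question, so the values change on every run; the name `wrong` is just for illustration:

```python
import numpy as np

# Same setup as in the question (random data, values differ per run)
sample = np.random.random((10, 10, 3))
argmax_indices = np.argmax(sample, axis=2)
print(argmax_indices.shape)  # (10, 10)

# A single integer index array only indexes the FIRST axis of sample:
# each entry (0, 1 or 2) selects a whole (10, 3) slab sample[k],
# so the index array's shape is prepended to the remaining axes.
wrong = sample[argmax_indices]
print(wrong.shape)  # (10, 10, 10, 3)
```

So the integer array ends up selecting along axis 0 instead of axis 2, which is why index arrays for the first two axes are needed as well.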


Solution

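  • Here's one approach. A single integer array such as argmax_indices only indexes the first axis, so pair it with index arrays for the first two axes and use advanced indexing: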

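    m,n = sample.shape[:2]
    I,J = np.ogrid[:m,:n]                      # index arrays of shapes (m,1) and (1,n)
    max_values = sample[I,J, argmax_indices]   # read the maxima first, shape (m,n)
    sample[I,J, argmax_indices] = 0            # then zero those positions in place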
    

    Sample step-by-step run

    1) Sample input array:

    In [261]: a = np.random.randint(0,9,(2,2,3))
    
    In [262]: a
    Out[262]: 
    array([[[8, 4, 6],
            [7, 6, 2]],
    
           [[1, 8, 1],
            [4, 6, 4]]])
    

    2) Get the argmax indices along axis=2:

    In [263]: idx = a.argmax(axis=2)
    

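    3) Get the shape and open-grid index arrays for the first two dimensions (np.ogrid[:m,:n] returns arrays of shapes (m,1) and (1,n), which broadcast against idx of shape (m,n)):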

    In [264]: m,n = a.shape[:2]
    
    In [265]: I,J = np.ogrid[:m,:n]
    

    4) Index with I, J and idx (advanced indexing) to get the max values:

    In [267]: max_values = a[I,J,idx]
    
    In [268]: max_values
    Out[268]: 
    array([[8, 7],
           [8, 6]])
    

    5) Verify that subtracting np.max(a, axis=2) from max_values gives an all-zeros array:

    In [306]: max_values - np.max(a, axis=2)
    Out[306]: 
    array([[0, 0],
           [0, 0]])
    

    6) Using advanced indexing again, set those positions to zero and verify visually once more:

    In [269]: a[I,J,idx] = 0
    
    In [270]: a
    Out[270]: 
    array([[[0, 4, 6], # <=== Compare this against the original version
            [0, 6, 2]],
    
           [[1, 0, 1],
            [4, 0, 4]]])
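As an aside, NumPy 1.15 and later also provide np.take_along_axis and np.put_along_axis, which do the same gather and scatter along one axis without building the ogrid index arrays by hand. This is a different technique from the answer above, sketched here on the same 2x2x3 input used in the step-by-step run:

```python
import numpy as np

# Same values as the step-by-step run
a = np.array([[[8, 4, 6],
               [7, 6, 2]],
              [[1, 8, 1],
               [4, 6, 4]]])

idx = a.argmax(axis=2)          # shape (2, 2)
idx3 = idx[..., np.newaxis]     # shape (2, 2, 1): same ndim as a, as take_along_axis requires

# Gather the maxima (before zeroing), then drop the trailing length-1 axis
max_values = np.take_along_axis(a, idx3, axis=2)[..., 0]
print(np.array_equal(max_values, a.max(axis=2)))  # True

# Write zeros at the same positions; modifies a in place and returns None
np.put_along_axis(a, idx3, 0, axis=2)
print(a)
```

Both functions take the axis as a parameter, so the same pattern works for any axis as long as the index array keeps the input's number of dimensions.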