Tags: python, arrays, numpy, indexing, max-pooling

python numpy maxpool: given an array and indices from argmax, returns max values


Suppose I have an array called view:

array([[[[ 7,  9],
         [10, 11]],

        [[19, 18],
         [20, 16]]],


       [[[24,  5],
         [ 6, 10]],

        [[18, 11],
         [45, 12]]]])

As you may know from max pooling, this is a view of the original input with a 2x2 kernel, so each innermost 2x2 block is one pooling window:

[[ 7,  9],    [[19, 18],
 [10, 11]],    [20, 16]], ...

The goal is to find both the max values and their indices. However, argmax only works along a single axis, so I need to flatten each kernel in view, i.e. flatten = view.reshape(2, 2, 4):

array([[[ 7,  9, 10, 11], [19, 18, 20, 16]],

       [[24,  5,  6, 10], [18, 11, 45, 12]]])

Now, with the help I got from my previous question, I can find the indices of the max values using inds = flatten.argmax(-1):

array([[3, 2],
       [0, 2]])

and values of max:

i, j = np.indices(flatten.shape[:-1])
flatten[i, j, inds]

array([[11, 20],
       [24, 45]])
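
As a side note, the same gather can probably be written without building the i, j index grids. The sketch below uses np.take_along_axis, which is not used in the original post; the name max_vals is only illustrative.

```python
import numpy as np

# The flattened kernels from above: one row of 4 values per 2x2 window.
flatten = np.array([[[7, 9, 10, 11], [19, 18, 20, 16]],
                    [[24, 5, 6, 10], [18, 11, 45, 12]]])

# Position of the maximum inside each window.
inds = flatten.argmax(-1)

# take_along_axis needs the index array to have the same number of
# dimensions as the data, so add a trailing axis and drop it afterwards.
max_vals = np.take_along_axis(flatten, inds[..., None], axis=-1).squeeze(-1)
print(max_vals)
```

This only reads values, so it does not help with writing back into original; it is just a shorter way to get the maxima.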

The problem
The problem arises when I flatten the view array. view is a strided view of the original array, i.e. view = as_strided(original, newshape, newstrides), so view and original share the same data. However, reshape returns a copy here, because the flattened layout cannot be described by strides over the original buffer, so any change made to the flattened array is not reflected in original. This is a problem during backpropagation.
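
A small check that illustrates this, using the arrays from the reproducible example further down. It is only a sketch: the name flat is illustrative, and np.shares_memory is used here purely to show which arrays alias the original buffer.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

original = np.array(
    [[[7, 9, 19, 18], [10, 11, 20, 16]],
     [[24, 5, 18, 11], [6, 10, 45, 12]]],
    dtype=np.float64,
)
view = as_strided(original, shape=(2, 1, 2, 2, 2), strides=(64, 64, 16, 32, 8))

# The strided view aliases the original buffer.
print(np.shares_memory(view, original))

# Merging the two kernel axes forces NumPy to copy the data.
flat = view.reshape(2, 1, 2, 4)
print(np.shares_memory(flat, original))

# Writing to the copy leaves the original untouched.
flat[0, 0, 0, 0] = 1000
print(original[0, 0, 0])
```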

My question
Given the array view and the indices inds, I'd like to change the max values in view to 1000 without using reshape or any other operation that breaks the 'bond' between view and original. Thanks for any help!

Reproducible example

import numpy as np
from numpy.lib.stride_tricks import as_strided

original = np.array([[[7, 9, 19, 18], [10, 11, 20, 16]], [[24, 5, 18, 11], [6, 10, 45, 12]]], dtype=np.float64)
view = as_strided(original, shape=(2, 1, 2, 2, 2), strides=(64, 32*2, 8*2, 32, 8))

I'd like to change the max value of each kernel in view to, say, 1000, in a way that is reflected in original, i.e. if I run view[0,0,0,0,0]=1000, then the first element of both view and original is 1000.


Solution

  • How about this:

    import numpy as np
    view = np.array(
        [[[[ 7,  9],
           [10, 11]],
          [[19, 18],
           [20, 16]]],
         [[[24,  5],
           [ 6, 10]],
          [[18, 11],
           [45, 12]]]]
    )
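    # Getting the indices of the max values.
    # Column-wise maximum of each kernel and the row where it occurs
    max0 = view.max(-2)
    idx2 = view.argmax(-2)
    # Flatten to one row per kernel; the last axis is the kernel width
    idx2 = idx2.reshape(-1, idx2.shape[-1])
    # Maximum of each kernel and the column where it occurs
    max1 = max0.max(-1)
    idx3 = max0.argmax(-1).flatten()
    # Row of each kernel maximum: the row index stored at the winning column
    idx2 = idx2[np.arange(idx3.size), idx3]

    # Indices along the first two axes, one pair per kernel
    idx0 = np.arange(view.shape[0]).repeat(view.shape[1])
    idx1 = np.arange(view.shape[1]).reshape(1, -1).repeat(view.shape[0], 0).flatten()

    # Replacing the maximal values with 1000
    view[idx0, idx1, idx2, idx3] = 1000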
    print(f'view = \n{view}')
    

    output:

    view = 
    [[[[   7    9]
       [  10 1000]]
    
      [[  19   18]
       [1000   16]]]
    
    
     [[[1000    5]
       [   6   10]]
    
      [[  18   11]
       [1000   12]]]]
    

    Basically, idx0 and idx1 identify each matrix along the first two dimensions, while idx2 and idx3 are the row and column of the maximal value within that matrix (the last two dimensions).
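
The answer builds view with np.array, so it does not show the write reaching original. The same index-then-assign idea can be sketched against the 5-D as_strided view from the question. This variant swaps in np.unravel_index to turn the flat argmax positions back into (row, column) pairs, which is not what the answer does; the names k_h, k_w, rows, cols and lead are illustrative. It assumes the pooling windows do not overlap, as in the question.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

original = np.array(
    [[[7, 9, 19, 18], [10, 11, 20, 16]],
     [[24, 5, 18, 11], [6, 10, 45, 12]]],
    dtype=np.float64,
)
# Same strided view as in the question: (batch, out_h, out_w, k_h, k_w).
view = as_strided(original, shape=(2, 1, 2, 2, 2), strides=(64, 64, 16, 32, 8))

k_h, k_w = view.shape[-2:]
# reshape copies here, which is fine: the copy is only read, never written.
inds = view.reshape(*view.shape[:-2], k_h * k_w).argmax(-1)
# Convert each flat position back to (row, col) inside its kernel.
rows, cols = np.unravel_index(inds, (k_h, k_w))
# Index grids for the leading (batch, out_h, out_w) axes.
lead = np.indices(inds.shape)
# Advanced-index assignment writes in place through the view.
view[(*lead, rows, cols)] = 1000
print(original)
```

Because the assignment goes through view itself rather than through the reshaped copy, the four maxima (11, 20, 24 and 45) should be replaced in original as well.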