Search code examples
pythonnumpypytorchmax-pooling

How to accelerate the code below? Implementing maxpool without center element


I know maxpool and I am using it in pytorch. Maxpool with dilated parameter is as belows: maxpool, dilated
Now I want a special form of maxpool, doing maxpool with out the central element. That is the kernel size is 3X3 but the central element should be deleted. Thus the result should come from the rest 8 elements.
Now I am using a for loop, how to accelerate this using numpy or pytorch or anything else?

import numpy as np
from timeit import default_timer as timer


def MaxPool_special(kh, kw, arr):
    """
    to do maxpool without central element
    :param kh:  should always be 3
    :param kw:  should always be 3
    :param arr:  the input array
    :return: arr_res: output array
    """
    h, w = arr.shape[:2]

    arr_res = np.array([[maxpool_ij(i, j, arr, kh, kw) for j in range(w)] for i in range(h)])

    return arr_res


def maxpool_ij(i, j, arr, dh, dw):
    """
    find the maximum value around point(i,j) with dilated parameter
    """
    Mmax = None
    imin, imax = i - dh, i + dh
    jmin, jmax = j - dw, j + dw
    if imin >= 0 and imax < h and jmin >= 0 and jmax < w:
        Mmax = np.max(
            arr[[imin, imin, imin, i, i, imax, imax, imax], [jmin, j, jmax, jmin, jmax, jmin, j, jmax]])
    elif imin < 0 and jmin < 0:
        Mmax = np.max(arr[[i, imax, imax], [jmax, j, jmax]])
    elif imin < 0 and jmax >= w:
        Mmax = np.max(arr[[i, imax, imax], [jmin, jmin, j]])
    elif imax >= h and jmin < 0:
        Mmax = np.max(arr[[imin, imin, i], [j, jmax, jmax]])
    elif imax >= h and jmax >= w:
        Mmax = np.max(arr[[imin, imin, i], [jmin, j, jmin]])
    elif imin < 0:
        Mmax = np.max(arr[[i, i, imax, imax, imax], [jmin, jmax, jmin, j, jmax]])
    elif imax >= h:
        Mmax = np.max(arr[[imin, imin, imin, i, i], [jmin, j, jmax, jmin, jmax]])
    elif jmin < 0:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [j, jmax, jmax, j, jmax]])
    elif jmax >= w:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [jmin, j, jmin, jmin, j]])

    assert Mmax, f'Wrong logic above!{imin, imax, jmin, jmax, h, w}'

    return Mmax

#  generate input array
h, w = 400, 500
arr = np.random.randint(0, 256, h * w).reshape(h, w)

tic = timer()
grayPool = MaxPool_special(3, 3, arr)
toc = timer()
print(f'time cost for for-loops: {toc - tic}')

Please help me accelerating this code, thanks!


Solution

  • Using torch.nn.Unfold can implement Maxpool without central element. Examples are as below:

    h, w = 7, 10
    x = torch.arange(0,h*w,dtype=torch.float).reshape(1,1,h,w)
    
    """
    x.shape: torch.Size([1, 1, 7, 10])
    x:
    tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
              [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
              [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
              [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
              [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
              [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
              [60., 61., 62., 63., 64., 65., 66., 67., 68., 69.]]]])
    """
    
    Unfold = torch.nn.Unfold(kernel_size=(3,5), stride=(2,2))
    xUfd = Unfold(x)
    """
    xUfd.shape: torch.Size([1, 15, 9])
    xUfd:
    tensor([[[ 0.,  2.,  4., 20., 22., 24., 40., 42., 44.],
             [ 1.,  3.,  5., 21., 23., 25., 41., 43., 45.],
             [ 2.,  4.,  6., 22., 24., 26., 42., 44., 46.],
             [ 3.,  5.,  7., 23., 25., 27., 43., 45., 47.],
             [ 4.,  6.,  8., 24., 26., 28., 44., 46., 48.],
             [10., 12., 14., 30., 32., 34., 50., 52., 54.],
             [11., 13., 15., 31., 33., 35., 51., 53., 55.],
             [12., 14., 16., 32., 34., 36., 52., 54., 56.],
             [13., 15., 17., 33., 35., 37., 53., 55., 57.],
             [14., 16., 18., 34., 36., 38., 54., 56., 58.],
             [20., 22., 24., 40., 42., 44., 60., 62., 64.],
             [21., 23., 25., 41., 43., 45., 61., 63., 65.],
             [22., 24., 26., 42., 44., 46., 62., 64., 66.],
             [23., 25., 27., 43., 45., 47., 63., 65., 67.],
             [24., 26., 28., 44., 46., 48., 64., 66., 68.]]])
    """
    
    xUfd = xUfd[:,:,[0,1,2,3,5,6,7,8]]
    xUfd = torch.max(xUfd, 2).values.reshape(3,5)
    """
    xUfd.shape: torch.Size([3, 5])
    xUfd:
    tensor([[44., 45., 46., 47., 48.],
            [54., 55., 56., 57., 58.],
            [64., 65., 66., 67., 68.]])
    """