Tags: python, numpy, optimization, conv-neural-network, max-pooling

How to optimize this MaxPool2d implementation


I made an implementation of MaxPool2d (it runs correctly; I compared it against PyTorch's). When testing it on the MNIST dataset, this function (updateOutput) takes a very long time to complete. How can I optimize this code using NumPy?

import numpy as np

class MaxPool2d(Module):  # Module is the surrounding framework's base class
    def __init__(self, kernel_size):
        super(MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.gradInput = None

    def updateOutput(self, input):
        #print("MaxPool updateOutput")
        #start_time = time.time()
        kernel = self.kernel_size
        poolH = input.shape[2] // kernel
        poolW = input.shape[3] // kernel
        self.output = np.zeros((input.shape[0], 
                                input.shape[1], 
                                poolH,
                                poolW))
        self.index = np.zeros((input.shape[0],
                                    input.shape[1],
                                    poolH,
                                    poolW,
                                    2), 
                                    dtype='int32')

        for i in range(input.shape[0]):
            for j in range(input.shape[1]):
                for k in range(0, input.shape[2] - kernel+1, kernel):
                    for m in range(0, input.shape[3] - kernel+1, kernel):
                        # current pooling window
                        M = input[i, j, k : k+kernel, m : m+kernel]
                        self.output[i, j, k // kernel, m // kernel] = M.max()
                        # remember the argmax position in input coordinates (needed for the backward pass)
                        self.index[i, j, k // kernel, m // kernel] = np.array(np.unravel_index(M.argmax(), M.shape)) + np.array((k, m))

        #print(f"time: {time.time() - start_time:.3f}s")
        return self.output

input shape = (batch_size, n_input_channels, h, w)

output shape = (batch_size, n_input_channels, h // kernel_size, w // kernel_size) — pooling leaves the channel count unchanged
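For reference, a quick shape check against PyTorch (the comparison mentioned above; kernel_size=2 is just an illustrative choice):

import numpy as np
import torch

x = np.random.rand(8, 3, 28, 28).astype(np.float32)
ref = torch.nn.MaxPool2d(kernel_size=2)(torch.from_numpy(x)).numpy()
print(ref.shape)  # (8, 3, 14, 14): channels unchanged, h and w halved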


Solution

  • For clarity I've simplified your example by removing the batch size and channel dimensions. Most of the time is spent on the calculation of M.max(). I've created a benchmark function, update_output_b, that runs the same loop over a constant array of ones.

    import time
    import numpy as np
    
    def timeit(cycles):
        def timed(func):
            def wrapper(*args, **kwargs):
                start_t = time.time()
                for _ in range(cycles):
                    result = func(*args, **kwargs)
                t = (time.time() - start_t) / cycles
                print(f'{func.__name__} mean execution time: {t:.3f}s')
                return result  # return the last result so decorated functions stay usable
    
            return wrapper
        return timed
    
    @timeit(100)
    def update_output_b(input, kernel):
        ones = np.ones((kernel, kernel))
    
        pool_h = input.shape[0] // kernel
        pool_w = input.shape[1] // kernel
        output = np.zeros((pool_h, pool_w))
    
        for i in range(0, input.shape[0] - kernel + 1, kernel):
            for j in range(0, input.shape[1] - kernel + 1, kernel):
                output[i // kernel, j // kernel] = ones.max()
    
        return output
    
    in_arr = np.random.rand(3001, 200)
    update_output_b(in_arr, 3)
    

    Its output is update_output_b mean execution time: 0.277s, as it doesn't use NumPy's fully vectorized operations. Whenever possible, you should prefer native NumPy functions over loops.

    In addition, taking slices of the input array slows execution, as access to contiguous memory is in most cases faster.
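    To see the layout issue concretely, here is a small check (my illustration, not part of the original benchmark): a 2D window slice is a strided view over non-adjacent memory, while a row slice is one contiguous block.

    import numpy as np

    a = np.random.rand(3001, 200)
    window = a[0:3, 0:3]   # 2D slice: strided view, rows are 200 elements apart
    row = a[0, 0:9]        # 1D row slice: one contiguous chunk of memory
    print(window.flags['C_CONTIGUOUS'])  # False
    print(row.flags['C_CONTIGUOUS'])     # True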

    @timeit(100)
    def update_output_1(input, kernel):
        pool_h = input.shape[0] // kernel
        pool_w = input.shape[1] // kernel
        output = np.zeros((pool_h, pool_w))
    
        for i in range(0, input.shape[0] - kernel + 1, kernel):
            for j in range(0, input.shape[1] - kernel + 1, kernel):
                M = input[i : i + kernel, j : j + kernel]
                output[i // kernel, j // kernel] = M.max()
    
        return output
    
    update_output_1(in_arr, 3)
    

    This code reports update_output_1 mean execution time: 0.332s (+55 ms compared to the previous one).

    I've added vectorized code below. It works ~20x faster (update_output_2 mean execution time: 0.015s); however, it is probably far from optimal.

    @timeit(100)
    def update_output_2(input, kernel):
        pool_h = input.shape[0] // kernel
        pool_w = input.shape[1] // kernel
        input_h = pool_h * kernel
        input_w = pool_w * kernel
    
        # crop input
        output = input[:input_h, :input_w]
        # calculate max along second axis
        output = output.reshape((-1, kernel))
        output = output.max(axis=1)
        # calculate max along first axis
        output = output.reshape((pool_h, kernel, pool_w))
        output = output.max(axis=1)
    
        return output
    
    update_output_2(in_arr, 3)
    

    It generates the output in 3 steps (traced on a tiny example below):

    • Cropping the input to a size divisible by the kernel
    • Calculating the max along the second axis (this reduces the offsets between slices in the first axis)
    • Calculating the max along the first axis
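    For instance, with a 4x4 array and kernel=2 (my illustration, not from the original answer; the result matches 2x2 max pooling done by hand):

    import numpy as np

    x = np.arange(16).reshape(4, 4)
    step1 = x[:4, :4]                             # crop (already divisible by 2)
    step2 = step1.reshape((-1, 2)).max(axis=1)    # max inside each horizontal pair
    step3 = step2.reshape((2, 2, 2)).max(axis=1)  # max across the two rows of each window
    print(step3)
    # [[ 5  7]
    #  [13 15]]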

    Edit:

    I've added modifications for retrieving the indices of the max values. However, you should check the index arithmetic, as I've only tested it on a random array.

    It calculates output_indices for the maximum along the second axis in each window, and then uses output_indices_selector to pick the maximum along the first axis.

    def update_output_3(input, kernel):
        pool_h = input.shape[0] // kernel
        pool_w = input.shape[1] // kernel
        input_h = pool_h * kernel
        input_w = pool_w * kernel
    
        # crop input
        output = input[:input_h, :input_w]
    
        # calculate max along second axis
        output_tmp = output.reshape((-1, kernel))
        output_indices = output_tmp.argmax(axis=1)
        output_indices += np.arange(output_indices.shape[0]) * kernel
        output_indices = np.unravel_index(output_indices, output.shape)
        output_tmp = output[output_indices]
    
        # calculate max along first axis
        output_tmp = output_tmp.reshape((pool_h, kernel, pool_w))
        # flat position of window (p, q), candidate row r in the per-row max arrays
        # is kernel*pool_w*p + pool_w*r + q; build it term by term:
        output_indices_selector = (kernel * pool_w * np.arange(pool_h).reshape(pool_h, 1))
        output_indices_selector = output_indices_selector.repeat(pool_w, axis=1)
        output_indices_selector += pool_w * output_tmp.argmax(axis=1)
        output_indices_selector += np.arange(pool_w)
        output_indices_selector = output_indices_selector.flatten()
    
        output_indices = (output_indices[0][output_indices_selector],
                          output_indices[1][output_indices_selector])
        output = output[output_indices].reshape(pool_h, pool_w)
    
        return output, output_indices
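
    Since the index arithmetic deserves a check, here is a small consistency test (naive_pool is my own reference implementation, not part of the answer):

    def naive_pool(input, kernel):
        pool_h, pool_w = input.shape[0] // kernel, input.shape[1] // kernel
        out = np.zeros((pool_h, pool_w))
        for i in range(pool_h):
            for j in range(pool_w):
                out[i, j] = input[i*kernel:(i+1)*kernel, j*kernel:(j+1)*kernel].max()
        return out

    output, output_indices = update_output_3(in_arr, 3)
    assert np.allclose(output, naive_pool(in_arr, 3))
    # the returned indices must point back at the max values themselves
    assert np.array_equal(in_arr[output_indices], output.flatten())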