Tags: arrays, numpy, convolution

Using numpy `as_strided` function to create patches, tiles, rolling or sliding windows of arbitrary dimension


Spent a while this morning looking for a generalized question to point duplicates to for questions about as_strided and/or how to make generalized window functions. There seem to be a lot of questions on how to (safely) create patches, sliding windows, rolling windows, tiles, or views onto an array for machine learning, convolution, image processing and/or numerical integration.

I'm looking for a generalized function that accepts window, step and axis parameters and returns an as_strided view over arbitrary dimensions. I will give my answer below, but I'm interested if anyone can make a more efficient method. In particular, I'm not sure that using np.squeeze() is the best approach, that my assert statements make the function safe enough to write to the resulting view, or how to handle the edge case of axis not being in ascending order.
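
To make the write-safety concern concrete, here is a minimal sketch (not part of the recipe below) showing that overlapping windows share memory, so a single write through the view shows up in several windows at once. The array and window sizes are arbitrary choices for illustration; `writeable=False` is the existing `as_strided` keyword for returning a read-only view.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(5)
n = a.strides[0]  # bytes between consecutive elements

# Overlapping windows of length 3: every row is a view onto the same buffer.
win = as_strided(a, shape=(3, 3), strides=(n, n))

# Writing through the view changes `a`, and therefore every overlapping window.
win[0, 1] = 100
aliased = (a[1], win[1, 0])  # both read the same memory location

# A read-only view turns accidental writes into an error instead.
safe = as_strided(a, shape=(3, 3), strides=(n, n), writeable=False)
try:
    safe[0, 0] = -1
    write_blocked = False
except ValueError:
    write_blocked = True

print(aliased, write_blocked)
```

Assertions on shape and axes alone cannot prevent this aliasing, so a read-only view is probably the safer default whenever windows overlap.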

DUE DILIGENCE

The most generalized function I can find is sklearn.feature_extraction.image.extract_patches written by @eickenberg (as well as the apparently equivalent skimage.util.view_as_windows), but those are not well documented on the net, and can't do windows over fewer axes than there are in the original array (for example, this question asks for a window of a certain size over just one axis). Also, questions often ask for a numpy-only answer.
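
For comparison, NumPy 1.20 and later ship `numpy.lib.stride_tricks.sliding_window_view`, which is not mentioned in the original post but covers the "window over fewer axes" case with NumPy alone. The sketch below assumes such a NumPy version is installed. Note that it appends the window axes at the end and has no step parameter, so steps are emulated here by slicing the window positions.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

a = np.arange(1000).reshape(10, 10, 10)

# Window of width 2 over axis 0 only; the window axis is appended last.
v = sliding_window_view(a, 2, axis=0)
print(v.shape)  # (9, 10, 10, 2)

# There is no step argument, so subsample the window positions by slicing.
tiles = sliding_window_view(a, (2, 2), axis=(0, 1))[::2, ::2]
print(tiles.shape)  # (5, 5, 10, 2, 2)
```

The returned view is read-only by default, which sidesteps the question of writing to overlapping windows.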

@Divakar created a generalized numpy function for 1-d inputs here, but higher-dimension inputs require a bit more care. I've made a bare bones 2D window over 3d input method, but it's not very extensible.
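
As a point of reference for the 1-d case, a rolling window can be sketched as below. This is not @Divakar's code; `rolling_1d` is a hypothetical helper written only to illustrate the idea, and it returns a read-only view.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def rolling_1d(a, window, step=1):
    """Read-only rolling windows of length `window` over a 1-D array (hypothetical helper)."""
    a = np.asarray(a)
    assert a.ndim == 1, "1-D input only"
    assert 0 < window <= a.size, "Window must fit inside the array"
    assert step > 0, "Only positive steps allowed"
    n_windows = (a.size - window) // step + 1
    stride = a.strides[0]
    # Moving to the next window skips `step` elements; inside a window
    # consecutive elements are one original stride apart.
    return as_strided(a, shape=(n_windows, window),
                      strides=(stride * step, stride), writeable=False)

x = np.arange(6)
print(rolling_1d(x, 3))          # four overlapping windows
print(rolling_1d(x, 3, step=2))  # two windows, starting at 0 and 2
```

With more than one dimension the same two ingredients (number of windows per axis, stride per axis) have to be tracked for every axis, which is where the extra care comes in.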


Solution

  • EDIT JAN 2020: Changed the iterable return from a list to a generator to save memory.

    EDIT OCT 2020: Put the generator in a separate function, since mixing generators and return statements doesn't work intuitively.

    Here's the recipe I have so far:

    import numpy as np

    def window_nd(a, window, steps = None, axis = None, gen_data = False):
            """
            Create a windowed view over `n`-dimensional input that uses an 
            `m`-dimensional window, with `m <= n`
            
            Parameters
            -------------
            a : Array-like
                The array to create the view on
                
            window : tuple or int
                If int, the size of the window in `axis`, or in all dimensions if 
                `axis == None`
                
                If tuple, the shape of the desired window.  `len(window)` must be:
                    equal to `len(axis)` if `axis` is not None, else 
                    equal to `len(a.shape)`, or 
                    1
                    
            steps : tuple, int or None
                The offset between consecutive windows in desired dimension
                If None, offset is one in all dimensions
                If int, the offset for all windows over `axis`
                If tuple, the steps along each `axis`.  
                    `len(steps)` must be equal to `len(axis)`
        
            axis : tuple, int or None
                The axes over which to apply the window
                If None, apply over all dimensions
                If tuple or int, the dimensions over which to apply the window
    
            gen_data : boolean
                If True, also return the data needed for a generator
        
            Returns
            -------
            
            a_view : ndarray
                A windowed view on the input array `a`.  If `gen_data` is True, 
                returns `(a_view, shp)` instead, where `shp` is the shape of the 
                grid of window positions needed for creating the generator
                
            """
            ashp = np.array(a.shape)
            
            if axis is not None:
                axs = np.array(axis, ndmin = 1)
                assert np.all(np.isin(axs, np.arange(ashp.size))), "Axes out of range"
            else:
                axs = np.arange(ashp.size)
                
            window = np.array(window, ndmin = 1)
            assert (window.size == axs.size) | (window.size == 1), "Window dims and axes don't match"
            wshp = ashp.copy()
            wshp[axs] = window
            assert np.all(wshp <= ashp), "Window is bigger than input array in axes"
            
            stp = np.ones_like(ashp)
            if steps is not None:
                steps = np.array(steps, ndmin = 1)
                assert np.all(steps > 0), "Only positive steps allowed"
                assert (steps.size == axs.size) | (steps.size == 1), "Steps and axes don't match"
                stp[axs] = steps
        
            astr = np.array(a.strides)
            
            shape = tuple((ashp - wshp) // stp + 1) + tuple(wshp)
            strides = tuple(astr * stp) + tuple(astr)
            
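            as_strided = np.lib.stride_tricks.as_strided
            a_view = as_strided(a, shape = shape, strides = strides)
            if gen_data:
                # Keep the view unsqueezed here so that every index in
                # shape[:-wshp.size] addresses exactly one window
                return a_view, shape[:-wshp.size]
            else:
                # Drop the length-1 axes left by dimensions that were not windowed
                return np.squeeze(a_view)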
    
    def window_gen(a, window, **kwargs):
        # Same parameters as window_nd, but returns a generator over the windows
        kwargs.pop('gen_data', None)  # discard any caller value; it is always set below
        a_view, shp = window_nd(a, window, gen_data  = True, **kwargs)
        for idx in np.ndindex(shp):
            yield a_view[idx]
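
The shape and stride arithmetic at the heart of the recipe can be checked by hand on a small array. The sketch below repeats just those two lines for a 4x6 input with a 2x3 window, sliding by 1 over rows and tiling by 3 over columns. The sizes are arbitrary choices for illustration, and the view is made read-only here, which the recipe itself does not do.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(24).reshape(4, 6)
ashp = np.array(a.shape)      # [4, 6]
astr = np.array(a.strides)    # byte strides of the input
wshp = np.array([2, 3])       # window shape
stp = np.array([1, 3])        # slide by 1 over rows, tile by 3 over columns

# Leading axes index the window position, trailing axes index inside a window.
shape = tuple((ashp - wshp) // stp + 1) + tuple(wshp)   # (3, 2, 2, 3)
# Moving one window along an axis skips `stp` elements;
# inside a window the original strides apply.
strides = tuple(astr * stp) + tuple(astr)

view = as_strided(a, shape=shape, strides=strides, writeable=False)

# Window (i, j) is the block starting at row i * 1, column j * 3.
i, j = 2, 1
block = a[i:i + 2, 3 * j:3 * j + 3]
print(np.array_equal(view[i, j], block))
```

Because the leading axes count window positions and the trailing axes repeat the window shape, an axis that is not windowed contributes a leading axis of length 1, which is what `np.squeeze` removes in the recipe.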
    

    Some test cases:

    a = np.arange(1000).reshape(10,10,10)
    
    window_nd(a, 4).shape # sliding (4x4x4) window
    Out: (7, 7, 7, 4, 4, 4)
    
    window_nd(a, 2, 2).shape # (2x2x2) blocks
    Out: (5, 5, 5, 2, 2, 2)
    
    window_nd(a, 2, 1, 0).shape # sliding window of width 2 over axis 0
    Out: (9, 2, 10, 10)
    
    window_nd(a, 2, 2, (0,1)).shape # tiled (2x2) windows over first and second axes
    Out: (5, 5, 2, 2, 10)
    
    window_nd(a,(4,3,2)).shape  # arbitrary sliding window
    Out: (7, 8, 9, 4, 3, 2)
    
    window_nd(a,(4,3,2),(1,5,2),(0,2,1)).shape #arbitrary windows, steps and axis
    Out: (7, 5, 2, 4, 2, 3) # note shape[-3:] != window as axes are out of order