Overlapping chunks in Xarray dataset for Kernel operations


I'm trying to run a custom filter with a 9x9 pixel kernel across a large satellite image. One satellite scene is ~40 GB, so to fit it into RAM I'm using xarray's option to chunk the dataset with dask.

My filter checks whether the kernel is complete (i.e. not missing data at the edge of the image). If it is not, a NaN is returned to prevent a potential bias (and I don't really care about the edges). I have now realized that this introduces NaNs not only at the edges of the image (expected behaviour) but also along the edges of each chunk, because the chunks don't overlap. Dask provides options to create chunks with an overlap, but are there any comparable capabilities in xarray? I found this issue, but it doesn't seem like there has been any progress in this regard.
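For reference, the dask overlap option mentioned here works by extending each chunk with a halo of `depth` elements copied from its neighbours (dask.array.overlap.overlap; dask.array.map_overlap does this before calling a function on each block). A minimal sketch on a toy 10x10 array with 5x5 chunks; the array, depth and boundary values are illustrative assumptions:

```python
import numpy as np
import dask.array as da
from dask.array.overlap import overlap

# Toy 10x10 array split into four non-overlapping 5x5 chunks
x = da.from_array(np.arange(100, dtype="f4").reshape(10, 10), chunks=(5, 5))

# Add a 2-element halo from neighbouring chunks on every interior side;
# boundary="none" means the outer edges of the array are not padded
halo = overlap(x, depth=2, boundary="none")

print(x.chunks)     # ((5, 5), (5, 5))
print(halo.chunks)  # ((7, 7), (7, 7))
```

Each chunk grows only on the sides that touch another chunk, so a kernel centred near a chunk border can see its neighbours' pixels.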

Some sample code (shortened version of my original code):


import numpy as np
import numba
import math
import xarray as xr


@numba.jit("f4[:,:](f4[:,:],i4)", nopython = True)
def water_anomaly_filter(input_arr, window_size = 9):
    # check if window size is odd
    if window_size%2 == 0:
        raise ValueError("Window size must be odd!")
    
    # prepare an output array with NaNs and the same dtype as the input
    output_arr = np.zeros_like(input_arr)
    output_arr[:] = np.nan
    
    # calculate how many pixels in x and y direction around the center pixel
    # are in the kernel
    pix_dist = math.floor(window_size/2-0.5)
    
    # create a dummy weight matrix
    weights = np.ones((window_size, window_size))
    
    # get the shape of the input array
    xn,yn = input_arr.shape
    
    # iterate over the x axis
    for x in range(xn):
        # determine limits of the kernel in x direction
        xmin = max(0, x - pix_dist)
        xmax = min(xn, x + pix_dist+1)
        
        # iterate over the y axis
        for y in range(yn):
            # determine limits of the kernel in y direction
            ymin = max(0, y - pix_dist)
            ymax = min(yn, y + pix_dist+1)

            # extract data values inside the kernel
            kernel = input_arr[xmin:xmax, ymin:ymax]
            
            # if the kernel is complete (i.e. not at image edge...) and it
            # is not all NaN
            if kernel.shape == weights.shape and not np.isnan(kernel).all():
                # apply the filter. In this example simply keep the original
                # value
                output_arr[x,y] = input_arr[x,y]
                
    return output_arr

def run_water_anomaly_filter_xr(xds, var_prefix = "band", 
                                window_size = 9):
    
    variables = [x for x in list(xds.variables) if x.startswith(var_prefix)]
    
    for var in variables[:2]:
        xds[var].values = water_anomaly_filter(xds[var].values, 
                                               window_size = window_size)
    
    return xds

def create_test_nc():

    data = np.random.randn(1000, 1000).astype(np.float32)

    rows = np.arange(54, 55, 0.001)
    cols = np.arange(10, 11, 0.001)

    ds = xr.Dataset(
        data_vars=dict(
            band_1=(["x", "y"], data)
        ),
        coords=dict(
            lon=(["x"], rows),
            lat=(["y"], cols),
        ),
        attrs=dict(description="Testdata"),
    )

    ds.to_netcdf("test.nc")

if __name__ == "__main__":

    # if required, create test data
    create_test_nc()
    
    # import data
    with xr.open_dataset("test.nc",
                         chunks = {"x": 50, 
                                   "y": 50},
                         
                         ) as xds:   

        xds_2 = xr.map_blocks(run_water_anomaly_filter_xr, 
                              xds,
                              template = xds).compute()

        xds_2["band_1"][:200,:200].plot()

This yields: [Figure: plot of band_1[:200, :200]]

You can clearly see the rows and columns of NaNs along the edges of each chunk.
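The pattern follows from the kernel size: with a 9x9 window, a pixel's kernel is complete only if the pixel is at least (9 - 1) // 2 = 4 pixels away from the edge of the array the filter receives, and without overlap that array is a single chunk. A minimal numpy-only sketch of this counting, using hypothetical helpers (complete_kernel_mask, blockwise_mask; not part of the code above) that mark the pixels where a full kernel fits, once for the whole image and once per non-overlapping chunk:

```python
import numpy as np

def complete_kernel_mask(shape, window_size=9):
    """Hypothetical helper: True where a full window_size x window_size
    kernel fits inside a 2D array of the given shape."""
    half = (window_size - 1) // 2
    mask = np.zeros(shape, dtype=bool)
    mask[half:shape[0] - half, half:shape[1] - half] = True
    return mask

def blockwise_mask(shape, chunk, window_size=9):
    """Hypothetical helper: apply complete_kernel_mask to each
    non-overlapping chunk separately, as map_blocks does without a halo."""
    mask = np.zeros(shape, dtype=bool)
    for i in range(0, shape[0], chunk):
        for j in range(0, shape[1], chunk):
            block = mask[i:i + chunk, j:j + chunk]  # a view into mask
            block[:] = complete_kernel_mask(block.shape, window_size)
    return mask

whole = complete_kernel_mask((100, 100))
chunked = blockwise_mask((100, 100), chunk=50)
print(whole.sum())    # 8464 = 92 * 92 valid pixels
print(chunked.sum())  # 7056 = 4 chunks * 42 * 42 valid pixels
```

With 50x50 chunks, every chunk loses a 4-pixel frame, which produces the 8-pixel-wide NaN stripes along each internal chunk border.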

I'm grateful for any suggestions. I would prefer overlapping chunks (or another solution) within xarray, but I'm also open to other approaches.


Solution

  • You can use Dask's map_overlap on the underlying dask array as follows:

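    import dask.array
    # depth=4 gives each block a 4-pixel halo (enough for a 9x9 kernel);
    # boundary='none' adds no padding at the outer image edges
    arr = dask.array.map_overlap(
        water_anomaly_filter, xds.band_1.data, dtype='f4', depth=4,
        boundary='none', window_size=9
    ).compute()
    da = xr.DataArray(arr, dims=xds.band_1.dims, coords=xds.band_1.coords)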
    

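    Note that depth must be at least (window_size - 1) // 2, i.e. 4 for a 9x9 window, so adjust the two together for your application. Also check the boundary argument: with boundary='none' the outer image edges are not padded, so they stay NaN as intended, whereas padding modes such as 'reflect' would make the edge kernels "complete". Older dask versions default to 'reflect', so it is safest to set it explicitly.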
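To check that the overlap actually removes the seams, here is a small self-contained sketch. It uses a hypothetical numpy-only mean_filter in place of the numba water_anomaly_filter (same rule: NaN wherever the 9x9 kernel is incomplete) and assumes numpy >= 1.20 for sliding_window_view. Run through map_overlap on 50x50 chunks, it should match the same filter applied to the whole image at once:

```python
import numpy as np
import dask.array as da

def mean_filter(block, window_size=9):
    """Hypothetical stand-in for water_anomaly_filter: mean over the
    window_size x window_size kernel where it fits completely, NaN elsewhere."""
    half = (window_size - 1) // 2
    out = np.full(block.shape, np.nan, dtype=block.dtype)
    if block.shape[0] < window_size or block.shape[1] < window_size:
        return out  # no complete kernel fits (also covers empty arrays)
    windows = np.lib.stride_tricks.sliding_window_view(
        block, (window_size, window_size))
    out[half:block.shape[0] - half, half:block.shape[1] - half] = \
        windows.mean(axis=(-2, -1))
    return out

window_size = 9
depth = (window_size - 1) // 2  # 4 = half-width of the 9x9 kernel

rng = np.random.default_rng(0)
image = rng.standard_normal((200, 200))
x = da.from_array(image, chunks=(50, 50))

# Each chunk gets a `depth`-pixel halo from its neighbours; boundary="none"
# leaves the outer image edges unpadded, so they remain NaN
result = da.map_overlap(mean_filter, x, depth=depth, boundary="none",
                        dtype=image.dtype, window_size=window_size).compute()

# Reference: the same filter applied to the whole image in one piece
reference = mean_filter(image, window_size)
print(np.allclose(result, reference, equal_nan=True))  # True
```

The computed array can then be wrapped back into an xarray.DataArray in the same way as in the answer's snippet.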