Tags: python, dask, python-xarray, chunks, zarr

How to reproject an xarray dataset memory-efficiently with chunks and dask?


Context: I have a netCDF file that I want to reproject. Reprojection is a costly operation, and I am learning how to use dask and zarr to do it efficiently without running out of RAM.

Code presentation: ds is a 3D xarray dataset (dimensions: time, y, x). The array is in projection EPSG:32607 and I want it in EPSG:3413. To do that, I open the dataset with dask chunking, reproject it with rioxarray, and then save it as zarr (with either xarray or dask, I am not sure which is best).

Problems: I want to do two things:

  1. Have the reprojection work on the chunks without crashing my memory.
  2. Find a valid chunk size automatically so I can save the dataset to .zarr through dask (last line of code).

How can I do that? (The sketch after the code below shows the kind of rechunking and zarr write pattern I am aiming for.)

Dataset description: (screenshot of the dataset in the original post)

Code:

# Original dataset: open the netCDF with dask chunks and keep only the variable "v"
#ds = xr.open_dataset('Cubes/temp/Cube_32607.nc', chunks={'time': 100, 'y': 100, 'x': 100}).v

# Example reproducing a dataset the size of ds
import numpy as np
import xarray as xr

start_date = np.datetime64('1984-06-01')
end_date = np.datetime64('2020-01-01')
num_dates = 18206
len_x = 333
len_y = 334
# Generate time dimension
date_array = np.linspace(start_date.astype('datetime64[s]').view('int64'),
                         end_date.astype('datetime64[s]').view('int64'),
                         num_dates).astype('datetime64[s]')
# Generate spatial dimensions
xs = np.linspace(486832.5, 566512.5, len_x)
ys = np.linspace(6695887.5, 6615967.5, len_y)

# Create example data and chunk it with dask
# (xr.Dataset has no `chunks` argument, and the time dimension is named 'time', not 'mid_date')
v = np.random.rand(len(date_array), len(ys), len(xs))
ds = xr.Dataset({'v': (('time', 'y', 'x'), v)},
                coords={'x': xs, 'y': ys, 'time': date_array}).chunk({'time': 100, 'y': 100, 'x': 100})



# Imports: rioxarray registers the .rio accessor, Resampling provides the resampling methods
import dask
import rioxarray
from rasterio.enums import Resampling

# Attribute a projection to the dataset (it is currently missing one)
ds = ds.rio.write_crs("EPSG:32607", inplace=True)

# Reproject the dataset
ds = ds.rio.reproject("EPSG:3413", resampling=Resampling.bilinear)

# Save it as zarr (compute=False builds the write lazily without executing it)
delayed_write = ds.to_zarr('Cubes/temp/Reprojected_Cube/Reprojected_Cube.zarr', mode='w', compute=False)

# Trigger the parallel computation to write the data to the zarr store (is this a better way to do it?)
dask.compute(delayed_write)
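
For reference, the part I am missing is something like the following sketch (the chunk sizes here are placeholders, since picking valid ones automatically is exactly what I do not know how to do): rechunking to uniform chunks after the reprojection so that zarr accepts the write.

# zarr requires uniform chunk sizes along each dimension (only the last chunk
# in each dimension may be smaller), so rechunk after the reprojection
ds = ds.chunk({'time': 100, 'y': 100, 'x': 100})
delayed_write = ds.to_zarr('Cubes/temp/Reprojected_Cube/Reprojected_Cube.zarr', mode='w', compute=False)
dask.compute(delayed_write)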

Solution

  • As far as I am aware, reprojection itself in rioxarray isn't dask-compatible, so you might run into issues.

    However, there are workarounds using rasterio's WarpedVRT:

    import rasterio
    import rioxarray
    from rasterio import crs
    from rasterio.vrt import WarpedVRT
    from dask.distributed import wait

    epsg_to = 4326
    # f is the path to the source raster (or a 'netcdf:...:variable' string, see below)
    with rasterio.open(f) as src:
        print('Source CRS: ' + str(src.crs))
        # resampling=1 corresponds to bilinear resampling
        with WarpedVRT(src, resampling=1, src_crs=src.crs, crs=crs.CRS.from_epsg(epsg_to),
                       warp_mem_limit=12000, warp_extras={'NUM_THREADS': 2}) as vrt:
            print('Destination CRS: ' + str(vrt.crs))
            ds = rioxarray.open_rasterio(vrt).chunk({'x': 1500, 'y': 1500, 'band': 1}).to_dataset(name='HLS_Red')

    # persist/wait assume a dask.distributed Client is running
    ds = ds.persist()
    wait(ds)
    ds


    Using this approach you would have to open and reproject the netCDF variables individually. You can open an individual netCDF variable with this syntax: 'netcdf:/path/to/file.nc:variable'
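
    A hedged sketch of that (the path, the variable name "v", the source EPSG:32607 and the target EPSG:3413 come from the question above; the chunk sizes are placeholders):

    import rasterio
    import rioxarray
    from rasterio.crs import CRS
    from rasterio.enums import Resampling
    from rasterio.vrt import WarpedVRT

    # Open a single netCDF variable through GDAL's netcdf: syntax
    with rasterio.open('netcdf:Cubes/temp/Cube_32607.nc:v') as src:
        # the file may lack CRS metadata, so pass the source CRS explicitly
        with WarpedVRT(src,
                       src_crs=CRS.from_epsg(32607),
                       crs=CRS.from_epsg(3413),
                       resampling=Resampling.bilinear) as vrt:
            ds_v = rioxarray.open_rasterio(vrt, chunks={'x': 100, 'y': 100, 'band': 100}).to_dataset(name='v')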

    Also see the GitHub issues here and here, as well as the GitHub gist this code comes from.
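
    To tie this back to the zarr part of the question, one possible follow-up (again just a sketch; the output path is the one from the question and the chunk sizes are placeholders) is to rechunk to uniform chunks and then let dask trigger the write:

    import dask

    # zarr requires uniform chunk sizes along each dimension
    ds_v = ds_v.chunk({'band': 100, 'y': 100, 'x': 100})
    delayed = ds_v.to_zarr('Cubes/temp/Reprojected_Cube/Reprojected_Cube.zarr', mode='w', compute=False)
    dask.compute(delayed)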