Tags: python, dask, python-xarray

Dealing with a very large xarray dataset: loading slices takes too much time


I have a very large netCDF dataset consisting of daily files from April 1985 to April 2024. Since the arrays are split into daily files, I usually open them with ds = xr.open_mfdataset('*.nc'). The entire dataset is about 1.07 TB, far more than I can load into memory:

[screenshot: the dataset's repr, showing a total size of 1.07 TB]

By slicing over the lat/lon coordinates with ds.sel(latitude=y, longitude=x, method='nearest') I get a single pixel's time series, which is far lighter than the original dataset and lets me perform the analysis I need:

[screenshot: the repr of the single-pixel selection, which is still backed by a large Dask graph]

However, even though the sliced dataset is now very light, it still takes a long time (more than an hour) to load into memory with ds.load(). This would not be a big deal if I didn't need to perform the operation more than 100,000 times, which would take an incredible 10 years to finish!

I don't have a powerful machine, but it's decent enough for the tasks I need. Although I expected this task to take some time, I'd really like to finish it before becoming a dad. Besides moving to a more powerful machine (which I doubt would bring the required time down to the order of days), is there any way I can optimize this task?


Solution

  • The problem is that the sliced result still carries the full task graph: in the screenshot you shared, the "Dask graph" field shows a large number of layers and chunks, all of which have to be worked through to materialize even a single pixel.
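
    You can confirm this on the selection itself. A quick diagnostic sketch (y and x stand for your query coordinates, as in the question):

    import xarray as xr


    ds = xr.open_mfdataset('*.nc')
    pixel = ds.sel(latitude=y, longitude=x, method='nearest')

    # chunks the selection still depends on, per data variable
    for name, var in pixel.data_vars.items():
        print(name, var.data.npartitions)

    # total number of tasks in the underlying graph
    print(len(pixel.__dask_graph__()))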

    Possible ways forward include:

    • using the chunks kwarg to pick a chunking better suited to this access pattern. Rough guess:
    import xarray as xr


    # larger chunks along time mean far fewer tasks for a single pixel
    ds = xr.open_mfdataset('*.nc', chunks={"time": 15_000})
    pixel = ds.sel(latitude=y, longitude=x, method='nearest')
    pixel.compute()
    
    • paying a fixed cost in time/disk space up front and reshaping the data into the format/chunks most compatible with your workflow, as sketched below;
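
    For example, a minimal one-off conversion to a Zarr store with time-contiguous chunks (a sketch: the chunk sizes and the output path reshaped.zarr are placeholders, and for a ~1 TB dataset you may prefer a dedicated tool such as the rechunker package to keep memory use bounded):

    import xarray as xr


    ds = xr.open_mfdataset('*.nc')

    # one chunk spanning the full time axis plus small spatial tiles, so each
    # pixel's complete time series lives in a single chunk on disk
    ds = ds.chunk({"time": -1, "latitude": 50, "longitude": 50})
    ds.to_zarr('reshaped.zarr')

    # afterwards, extracting one point touches only one chunk per variable
    ds = xr.open_zarr('reshaped.zarr')
    pixel = ds.sel(latitude=y, longitude=x, method='nearest').load()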

    • trying dask.optimize, rough pseudocode:

    import dask
    import xarray as xr
    
    
    ds = xr.open_mfdataset('*.nc')
    pixel = ds.sel(latitude=y, longitude=x, method='nearest')
    (optimized_pixel,) = dask.optimize(pixel)
    optimized_pixel.compute()
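
    Here dask.optimize runs Dask's graph optimizations up front, which typically culls tasks the selected pixel does not need and fuses the rest, so the subsequent compute() has a much smaller graph to schedule.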