Tags: python, dask, python-xarray, dask-distributed, zarr

Concurrently write xarray datasets to zarr - how to efficiently scale with dask distributed


TLDR:

How can I efficiently use dask-distributed to write a number of dask-backed xarray datasets to a zarr store on AWS S3?

Details:

I have a workflow that takes a list of raster datasets on S3 and generates a dask-array backed xarray dataset.

I need to iterate over a number of groups, where for each group the workflow fetches the raster datasets that belong to that group and generates the corresponding xarray dataset.

Now I want to write the data from these datasets into a zarr store on S3 (the same store for all of them, just using the group parameter).

This is what the pseudo code for sequential processing could look like:

import dask
import fsspec
from dask.distributed import Client

client = Client(...)  # using a distributed cluster

zarr_store = fsspec.get_mapper("s3://bucket/key.zarr")

for group_select in groups:
    
    xr_dataset = get_dataset_for_group(group_select)
    
    # totally unnecessary, just to illustrate that this is a lazy dataset, nothing has been loaded yet
    assert dask.is_dask_collection(xr_dataset)
    
    xr_dataset.to_zarr(zarr_store, group=group_select)

This works very well: once to_zarr is executed, the data is loaded and stored on S3, with the tasks running in parallel.


Now I'd like to run this in parallel using dask.distributed. This is what I've tried and what issues I've encountered:

1. Using .to_zarr(..., compute=False) to collect a list of delayed tasks

This works in principle, but it is quite slow: creating each delayed task takes around 3-4 seconds, and I need to do this 100+ times, so it takes 4-5 minutes before any computation actually starts.
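For reference, this is roughly how I'm collecting the delayed writes, as a sketch using the same zarr_store, groups and get_dataset_for_group as in the pseudo code above:

import dask

delayed_writes = []
for group_select in groups:
    xr_dataset = get_dataset_for_group(group_select)
    # building this delayed object is the slow part (~3-4 s per group)
    delayed_writes.append(
        xr_dataset.to_zarr(zarr_store, group=group_select, compute=False)
    )

# only now is any data actually loaded and written
dask.compute(*delayed_writes)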

2. Wrapping it into dask.delayed

This speeds up the creation of tasks immensely, but the write to the zarr store isn't split among the workers; instead, the worker processing the task gathers all the data once the loading tasks are finished and writes it to zarr by itself.
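Roughly what this looked like, again only as a sketch; datasets is the list of lazy per-group datasets (the same list I zip with groups further down):

import dask

# wrap only the write call; each dataset is still lazy / dask-backed
delayed_writes = [
    dask.delayed(ds.to_zarr)(zarr_store, group=group_select)
    for ds, group_select in zip(datasets, groups)
]

# each write runs as a single task, so the worker executing it ends up
# gathering all of that group's data before writing it to zarr itself
dask.compute(*delayed_writes)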

3. Wrapping to_zarr in a custom function and passing it to client.submit

This looked to be the most promising option. I've just wrapped the to_zarr call in a custom function that can be called from a worker:

from dask.distributed import worker_client

def dump(ds, target, group=None):
    # reconnect to the scheduler from within the worker so the write
    # is again distributed across many tasks
    with worker_client() as client:
        ds.to_zarr(store=target, group=group)
    return True

Doing this with a worker_client puts the writing tasks back on the scheduler and solves the issue I had above with dask.delayed.

However, when I submit this function repeatedly (I need to do this 100+ times) along the lines of

futures = [client.submit(dump, x, target, g) for x, g in zip(datasets, groups)]

I quickly overwhelm the scheduler with tasks to process.

The only obvious solution I can think of is to split the datasets into batches and only start a new batch once the previous one is finished. But isn't there a more elegant solution? Or is there built-in functionality in dask (distributed)?


Solution

  • In my experience/environment, it is easy to overwhelm the scheduler with too many tasks (and also too many workers to coordinate), so splitting things into batches typically works.

    To create a moving queue of work, you can use as_completed, submitting/adding tasks each time another task is completed. See these related answers: 1 and 2.
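    As a sketch of that pattern, reusing dump, datasets, groups, client and zarr_store from the question (the in-flight limit of 10 is just an illustrative value):

    import itertools
    from dask.distributed import as_completed

    n_inflight = 10  # how many group writes to keep in flight at once
    work = iter(zip(datasets, groups))

    # seed the moving queue with an initial batch of submissions
    initial = [client.submit(dump, ds, zarr_store, g)
               for ds, g in itertools.islice(work, n_inflight)]
    queue = as_completed(initial)

    for future in queue:
        future.result()  # surface any write errors
        nxt = next(work, None)
        if nxt is not None:
            ds, g = nxt
            # each finished write frees a slot for the next submission, so the
            # scheduler only ever holds ~n_inflight of these graphs at a time
            queue.add(client.submit(dump, ds, zarr_store, g))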