TLDR: How can I efficiently use `dask-distributed` to write a number of dask-backed xarray datasets to a zarr store on AWS S3?
Details:
I have a workflow that takes a list of raster datasets on S3 and generates a dask-array-backed xarray dataset.
I need to iterate over a number of groups; for each group, the workflow gets the raster datasets belonging to that group and generates the corresponding xarray dataset.
Now I want to write the data from these datasets into a zarr store on S3 (the same store for all groups, distinguished only by the `group` parameter).
This is what pseudocode for sequential processing could look like:
```python
import dask
import fsspec
from distributed import Client

client = Client(...)  # using a distributed cluster
zarr_store = fsspec.get_mapper("s3://bucket/key.zarr")

for group_select in groups:
    xr_dataset = get_dataset_for_group(group_select)
    # totally unnecessary, just to illustrate that this is a lazy dataset, nothing has been loaded yet
    assert dask.is_dask_collection(xr_dataset)
    xr_dataset.to_zarr(zarr_store, group=group_select)
```
This works very well: once `to_zarr` is executed, the data is loaded and stored on S3, with the tasks running in parallel.
Now I'd like to run this in parallel using `dask.distributed`. This is what I've tried and the issues I've encountered:
1. Using `.to_zarr(..., compute=False)` to collect a list of delayed tasks

This works in principle, but it is quite slow. Creating a task takes around 3-4 seconds, and I need to run this 100+ times, which means 4-5 minutes pass before any computation actually starts.
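Roughly, this pattern looks like the following (a minimal sketch reusing the names from the snippet above):

```python
import dask

# collect one delayed write per group instead of writing immediately
delayed_writes = []
for group_select in groups:
    xr_dataset = get_dataset_for_group(group_select)
    # compute=False returns a dask.delayed object representing the write
    delayed_writes.append(
        xr_dataset.to_zarr(zarr_store, group=group_select, compute=False)
    )

# trigger all writes on the cluster at once
dask.compute(*delayed_writes)
```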
2. Wrapping it in `dask.delayed`

This speeds up the creation of tasks immensely, but the write to the zarr store isn't split among workers; instead, the worker processing the task gathers all the data once the loading tasks have finished and writes it to zarr.
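In pattern form, this attempt looks roughly like this (again only a sketch, reusing the names from above; note that the whole write for one group ends up as a single task):

```python
import dask

@dask.delayed
def write_group(group_select):
    # the dataset is built and written inside one delayed task,
    # so a single worker ends up gathering all the data for the group
    xr_dataset = get_dataset_for_group(group_select)
    xr_dataset.to_zarr(zarr_store, group=group_select)
    return group_select

dask.compute(*[write_group(g) for g in groups])
```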
3. Wrapping `to_zarr` in a custom function and passing it to `client.submit`

This looked to be the most promising option. I've simply wrapped the `to_zarr` call in a custom function that can be called from a worker:
```python
from distributed import worker_client

def dump(ds, target, group=None):
    # worker_client() lets a task running on a worker submit the
    # write tasks back to the scheduler instead of computing locally
    with worker_client() as client:
        ds.to_zarr(store=target, group=group)
    return True
```
Doing this with a `worker_client` puts the writing tasks back on the scheduler and solves the issue I had above with `dask.delayed`.
However, when I submit this function repeatedly (I need to do this 100+ times) along the lines of
```python
futures = [client.submit(dump, x, target, g) for x, g in zip(datasets, groups)]
```
I quickly overwhelm the scheduler with tasks to process.
The only obvious solution I can think of is to split the datasets into batches and only start a new batch once the previous one has finished. But isn't there a more elegant solution? Or is there built-in functionality in dask (distributed)?
In my experience/environment, it is easy to overwhelm the scheduler with too many tasks (and also too many workers to coordinate), so splitting things into batches typically works.
To create a moving queue of work, you can use `as_completed`, submitting/adding a new task each time another task is completed. See these related answers: 1 and 2.
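A sketch of what that could look like with the `dump` function and variables from the question (keeping 4 writes in flight is an arbitrary choice; tune it to what your scheduler handles comfortably):

```python
from distributed import as_completed

n_inflight = 4  # arbitrary: how many group writes to keep in flight at once
work = list(zip(datasets, groups))

# seed the queue with the first few submissions
queue = as_completed(
    [client.submit(dump, ds, target, g) for ds, g in work[:n_inflight]]
)

remaining = iter(work[n_inflight:])
for future in queue:
    future.result()  # re-raise any exception from the worker
    try:
        ds, g = next(remaining)
    except StopIteration:
        continue
    # each finished write frees a slot for the next submission
    queue.add(client.submit(dump, ds, target, g))
```

This keeps a fixed number of writes in flight without ever handing the scheduler the full list of 100+ top-level tasks at once.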