Tags: python, dask, python-xarray, zarr

How to efficiently convert npy to xarray / zarr


I have a 37 GB .npy file that I would like to convert to a Zarr store so that I can include coordinate labels. I have code that does this in theory, but I keep running out of memory. I want to use Dask to do the conversion in chunks, but I still run out of memory.

The data is "thickness maps" for people's femoral cartilage. Each map is a 310x310 float array, and there are 47789 of these maps. So the data shape is (47789, 310, 310).

Step 1: Load the npy file as a memmapped Dask array.

import numpy as np, dask.array, xarray as xr  # imports used across the three steps

fem_dask = dask.array.from_array(np.load('/Volumes/T7/cartilagenpy20220602/femoral.npy', mmap_mode='r'),
                                 chunks=(300, -1, -1))
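
As a quick sanity check (not in the original post), the chunks argument is what bounds per-task memory here. Since 47789 × 310 × 310 × 8 bytes ≈ 37 GB, the data is presumably float64, so each 300-map chunk is roughly 230 MB:

print(fem_dask.chunksize)   # (300, 310, 310)
print(fem_dask.dtype)       # float64, judging by the 37 GB file size
chunk_mb = np.prod(fem_dask.chunksize) * fem_dask.dtype.itemsize / 1e6
print(f"~{chunk_mb:.0f} MB per chunk")   # roughly 230 MB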

Step 2: Make an xarray DataArray over the Dask array, with the desired coordinates. I have several coordinates for the 'map' dimension that come from a metadata table (a pandas DataFrame).

fem_xr = xr.DataArray(fem_dask, dims=['map','x','y'],
                         coords={'patient_id': ('map', metadata['patient_id']),
                                 'side':       ('map', metadata['side'].astype(np.string_)),
                                 'timepoint':  ('map', metadata['timepoint'])
                                })

[screenshot: xarray.DataArray repr showing the dask-backed (47789, 310, 310) array with the patient_id, side and timepoint coordinates]
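
(A quick check, not from the original post: the wrapped array should still be lazy and chunked at this point; nothing has been read from disk yet.)

print(fem_xr.data)                        # still a chunked dask array, chunksize=(300, 310, 310)
print(fem_xr.coords['patient_id'].shape)  # (47789,) -- coordinates line up with the 'map' dimension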

Step 3: Write to Zarr.

fem_ds = fem_xr.to_dataset(name='femoral')  # Zarr requires Dataset, not DataArray
res = fem_ds.to_zarr('/Volumes/T7/femoral.zarr', 
                     encoding={'femoral': {'dtype': 'float32'}},
                     compute=False)
res.visualize()

See the task graph below if desired: [screenshot: task graph from res.visualize()]
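
(Worth noting, though not spelled out in the original post: with compute=False, res is just a lazy object; nothing is written until it is computed, which is why the trouble only starts on the next call.)

print(type(res))   # dask.delayed.Delayed -- the Zarr store is only written when res.compute() runs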

When I call res.compute(), RAM use quickly climbs out of control. The other Python processes, which I think are the Dask workers, seem to be inactive:

[screenshot: process monitor showing the other Python processes mostly idle]

But a bit later they are active: one of those Python processes now has 20 GB of RAM and another has 36 GB. [screenshot: process monitor showing two Python worker processes at 20 GB and 36 GB]

We can also confirm this from the Dask dashboard:

[screenshot: Dask dashboard showing worker memory climbing]

Eventually all the workers get killed and the task errors out. How can I do this in an efficient way that correctly uses Dask, xarray, and Zarr, without running out of RAM (or melting the laptop)?


Solution

  • using threads

    If the dask workers can share threads, your code should just work. If you don't configure the cluster explicitly, dask.distributed creates a LocalCluster with default arguments, which uses processes; that's the behavior you're seeing. To fix it, explicitly create a cluster that uses threads:

    import numpy as np
    import dask.array as dda
    import dask.distributed

    # use threads, not processes
    cluster = dask.distributed.LocalCluster(processes=False)
    client = dask.distributed.Client(cluster)

    arr = np.load('myarr.npy', mmap_mode='r')
    da = dda.from_array(arr).rechunk(chunks=(100, 310, 310))
    da.to_zarr('myarr.zarr', mode='w')
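
    A variation on the same idea (my addition, not from the original answer): if you don't need the distributed dashboard, you can skip dask.distributed entirely and force the local threaded scheduler just for this write; threads within one process share the memmap the same way:

    import dask
    import dask.array as dda
    import numpy as np

    arr = np.load('myarr.npy', mmap_mode='r')
    da = dda.from_array(arr, chunks=(100, 310, 310))
    with dask.config.set(scheduler='threads'):   # same to_zarr call as above, threaded scheduler
        da.to_zarr('myarr.zarr', mode='w')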
    

  • using processes or distributed workers

    If you're using a cluster whose workers cannot share threads, such as a JobQueue, KubernetesCluster, etc., you can use the following to read the npy file, assuming it's on a networked filesystem or is otherwise available to all workers.

    Here's a workflow that creates an empty array from the memory map, then maps the read job over it with dask.array.map_blocks. The key is the optional block_info keyword, which gives the location of each block within the array; each worker uses it to open its own memmap and slice out just its block:

    def load_npy_chunk(da, fp, block_info=None, mmap_mode='r'):
        """Load a slice of the .npy array, making use of the block_info kwarg"""
        np_mmap = np.load(fp, mmap_mode=mmap_mode)
        # block_info[0]['array-location'] lists (start, stop) index pairs for this block
        array_location = block_info[0]['array-location']
        dim_slicer = tuple(slice(*loc) for loc in array_location)
        return np_mmap[dim_slicer]
    
    def dask_read_npy(fp, chunks=None, mmap_mode='r'):
        """Read metadata by opening the mmap, then send the read job to workers"""
        np_mmap = np.load(fp, mmap_mode=mmap_mode)
        da = dda.empty_like(np_mmap, chunks=chunks)
        return da.map_blocks(load_npy_chunk, fp=fp, mmap_mode=mmap_mode, meta=da)
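
    For orientation (values here are illustrative, not from the original answer), block_info[0] inside load_npy_chunk is a dict describing the block currently being computed. For the second chunk of a (300, -1, -1)-chunked (47789, 310, 310) array it looks roughly like this, which the slicing above turns into a small window onto the memmap:

    # approximate contents of block_info[0] for the second chunk:
    # {'shape': (47789, 310, 310),
    #  'num-chunks': (160, 1, 1),
    #  'array-location': [(300, 600), (0, 310), (0, 310)],
    #  'chunk-location': (1, 0, 0),
    #  ...}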
    

    This works for me on a demo of the same size (you could add the xarray.DataArray creation/formatting step at the end, as sketched after the code below; the dask ops work fine and worker memory stays below 1 GB for me):

    import numpy as np, dask.array as dda, xarray as xr, pandas as pd, dask.distributed
    
    ### insert/import above functions here
    
    # save a large numpy array
    np.save('myarr.npy', np.empty(shape=(47789, 310, 310), dtype=np.float32))
    
    cluster = dask.distributed.LocalCluster()
    client = dask.distributed.Client(cluster)
    
    da = dask_read_npy('myarr.npy', chunks=(300, -1, -1), mmap_mode='r')
    da.to_zarr('myarr.zarr', mode='w')
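
    To finish the original task from here, you could wrap the lazy array in an xarray.Dataset and let xarray write the Zarr store. A sketch reusing the question's coordinate names (the metadata columns and its source file are the asker's setup, not verified here):

    metadata = pd.read_csv('metadata.csv')   # hypothetical source of the per-map metadata

    fem_ds = xr.Dataset(
        {'femoral': (('map', 'x', 'y'), da)},
        coords={'patient_id': ('map', metadata['patient_id'].values),
                'side':       ('map', metadata['side'].astype(np.string_).values),
                'timepoint':  ('map', metadata['timepoint'].values)},
    )
    fem_ds.to_zarr('femoral.zarr', mode='w',
                   encoding={'femoral': {'dtype': 'float32'}})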