Tags: python, arrays, dask, dask-delayed

Dask looping over library function call


Goal

I would like to parallelize a loop with dask that calls a library function inside the loop. This function, mhw.detect(), calculates some statistics on a slice of a numpy array (one pixel's time series). The slices are independent of one another, so I was hoping that dask could compute them in parallel and store them all in the same output array.

Code

The flow of the code I am working on is:

import numpy as np
import marineHeatWaves as mhw
from dask import delayed

# Create fake input data
lat_size, long_size = 100, 100
data = np.random.randint(0, 30, size=(10_000, lat_size, long_size))  # size = (time, latitude, longitude)
time = np.arange(730_000, 740_000)  # time in ordinal days

# Initialize an empty array to hold the output
output_array = np.empty(data.shape)

# loop through each pixel in the data array
for idx_lat in range(lat_size):
    for idx_long in range(long_size):
        # Extract a slice of data
        data_slice = data[:, idx_lat, idx_long]
        # Use the library function to calculate the stats for the pixel
        # `library_output` is a dictionary that has a numpy array inside it
        _, library_output = delayed(mhw.detect)(time, data_slice)
        # Update the output array with the calculated values from the library
        output_array[:, idx_lat, idx_long] = library_output['seas']

Previous efforts

When I run this code I get the error TypeError: Delayed objects of unspecified length are not iterable. Another Stack Overflow post discusses this error and resolves it by converting the output of the delayed function to a delayed object. However, because the output object is created inside the library rather than by my own code, I am not sure how to convert it to a delayed object.
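A minimal sketch of one way around the unpacking error: dask.delayed accepts an `nout` argument that declares how many items the wrapped call returns, which gives the resulting Delayed a length so it can be unpacked. `fake_detect` is a hypothetical stand-in for mhw.detect, assumed (from the question's `_, library_output = ...` and `['seas']`) to return a 2-tuple whose second element is a dict holding a 'seas' array; the array sizes are shrunk for illustration. The per-pixel results stay lazy and are computed in one dask.compute call instead of being written into the NumPy array inside the loop.

```python
import numpy as np
from dask import compute, delayed


# Hypothetical stand-in for mhw.detect: returns (events, climatology_dict)
def fake_detect(t, temp):
    seas = np.full(temp.shape, temp.mean())
    return {}, {"seas": seas}


n_time, n_lat, n_long = 50, 4, 3
data = np.random.randint(0, 30, size=(n_time, n_lat, n_long))
time = np.arange(730_000, 730_000 + n_time)

seas_by_pixel = {}
for i in range(n_lat):
    for j in range(n_long):
        # nout=2 tells dask the call returns two items, so unpacking works
        _, clim = delayed(fake_detect, nout=2)(time, data[:, i, j])
        # Indexing a Delayed just adds another lazy task; nothing runs yet
        seas_by_pixel[(i, j)] = clim["seas"]

# Compute all pixels at once so the scheduler can run them in parallel
(results,) = compute(seas_by_pixel)

output_array = np.empty(data.shape)
for (i, j), seas in results.items():
    output_array[:, i, j] = seas
```

With 10,000 pixels this creates one small task per pixel, so scheduling overhead can be noticeable; grouping pixels into larger chunks reduces the task count.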

I've also tried wrapping the last line in da.from_delayed(), as in output_array[:, idx_lat, idx_long] = da.from_delayed(library_output['seas']), and initializing output_array with da.empty(data.shape). I get the same error, though, presumably because execution never gets past the line that calls the library function, delayed(mhw.detect)(time, data_slice).
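dask.array.from_delayed can wrap a lazy per-pixel result as a dask array, but it needs the `shape` and `dtype` up front because the value has not been computed yet, and it only helps once the tuple unpacking works (here via `nout=2`). A sketch that builds the whole output lazily and computes it once; `fake_detect` is a hypothetical stand-in for mhw.detect, assumed to return a 2-tuple whose second element is a dict with a 'seas' array of the same length as the input series.

```python
import dask.array as da
import numpy as np
from dask import delayed


# Hypothetical stand-in for mhw.detect
def fake_detect(t, temp):
    return {}, {"seas": np.full(temp.shape, temp.mean())}


n_time, n_lat, n_long = 50, 4, 3
data = np.random.randint(0, 30, size=(n_time, n_lat, n_long))
time = np.arange(730_000, 730_000 + n_time)

rows = []
for i in range(n_lat):
    cols = []
    for j in range(n_long):
        # nout=2 makes the lazy result unpackable
        _, clim = delayed(fake_detect, nout=2)(time, data[:, i, j])
        # from_delayed needs shape and dtype because the value is still lazy
        cols.append(da.from_delayed(clim["seas"], shape=(n_time,), dtype=float))
    # (time,) series stacked on axis 1 -> (time, long)
    rows.append(da.stack(cols, axis=1))

# (time, long) rows stacked on axis 1 -> (time, lat, long)
output = da.stack(rows, axis=1)
output_array = output.compute()
```

Assigning a dask array into a NumPy slice inside the loop would force each piece to be computed one at a time, so the pieces are stacked lazily and computed together at the end.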

Is it possible to parallelize this? And is it even a reasonable approach to ask dask to compute all the slices in parallel and assemble them into one output array?

Full Traceback

TypeError                                 Traceback (most recent call last)
/home/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 44' in <cell line: 10>()
     13 data_slice = data[:, idx_lat, idx_long]
     14 # Use the library function to calculate the stats for the pixel
---> 15 _, point_clim = delayed(mhw.detect)(time_ordinal, data_slice)
     16 # Update the output array with the calculated values from the library
     17 output_array[:, idx_lat, idx_long] = point_clim['seas']

File ~/.conda/envs/dask/lib/python3.10/site-packages/dask/delayed.py:581, in Delayed.__iter__(self)
    579 def __iter__(self):
    580     if self._length is None:
--> 581         raise TypeError("Delayed objects of unspecified length are not iterable")
    582     for i in range(self._length):
    583         yield self[i]

TypeError: Delayed objects of unspecified length are not iterable

Update

Using dask.array.apply_along_axis() as suggested:

# Create fake input data
lat_size, long_size = 100, 100
data = np.random.randint(0, 30, size=(10_000, long_size, lat_size))  # size = (time, longitude, latitude)
data = dask.array.from_array(data, chunks=(-1, 100, 100))
time = np.arange(730_000, 740_000)  # time in ordinal days

# Initialize an empty array to hold the output
output_array = np.empty(data.shape)

# define a wrapper to rearrange arguments
def func1d(arr, time, shape=(10000,)):
   print(arr.shape)
   return mhw.detect(time, arr)

res = dask.array.apply_along_axis(func1d, 0, data, time=time)

With the output:

(1,)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/homes/metogra/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 48' in <cell line: 15>()
     12    print(arr.shape)
     13    return mhw.detect(time, arr)
---> 15 res = dask.array.apply_along_axis(func1d, 0, data, time=time)

File ~/.conda/envs/dask/lib/python3.10/site-packages/dask/array/routines.py:508, in apply_along_axis(func1d, axis, arr, dtype, shape, *args, **kwargs)
    506 if shape is None or dtype is None:
    507     test_data = np.ones((1,), dtype=arr.dtype)
--> 508     test_result = np.array(func1d(test_data, *args, **kwargs))
    509     if shape is None:
    510         shape = test_result.shape

/homes/metogra/rwegener/mhw-ocetrac-census/notebooks/ejoliver_subset_MUR.ipynb Cell 48' in func1d(arr, time, shape)
     11 def func1d(arr, time, shape=(10000,)):
     12    print(arr.shape)
---> 13    return mhw.detect(time, arr)

File ~/.conda/envs/dask/lib/python3.10/site-packages/marineHeatWaves-0.28-py3.10.egg/marineHeatWaves.py:280, in detect(t, temp, climatologyPeriod, pctile, windowHalfWidth, smoothPercentile, smoothPercentileWidth, minDuration, joinAcrossGaps, maxGap, maxPadLength, coldSpells, alternateClimatology, Ly)
    278     tt = tt[tt>=0] # Reject indices "before" the first element
    279     tt = tt[tt<TClim] # Reject indices "after" the last element
--> 280     thresh_climYear[d-1] = np.nanpercentile(tempClim[tt.astype(int)], pctile)
    281     seas_climYear[d-1] = np.nanmean(tempClim[tt.astype(int)])
    282 # Special case for Feb 29

IndexError: index 115 is out of bounds for axis 0 with size 1
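The traceback and the printed `(1,)` point to the cause: when `shape` or `dtype` is not given, dask.array.apply_along_axis calls func1d once on `np.ones((1,))` to infer the output, and mhw.detect cannot handle a length-1 series. The `shape=(10000,)` parameter on func1d is never seen by dask; `shape` and `dtype` must be passed to apply_along_axis itself, and the wrapper should return an array (for example the 'seas' entry) rather than the whole return tuple. A sketch with small sizes, where `fake_detect` is a hypothetical stand-in for mhw.detect:

```python
import dask.array as da
import numpy as np


# Hypothetical stand-in for mhw.detect: returns (events, climatology_dict)
def fake_detect(t, temp):
    return {}, {"seas": np.full(temp.shape, temp.mean())}


n_time, n_lat, n_long = 50, 4, 3
data_np = np.random.randint(0, 30, size=(n_time, n_lat, n_long))
data = da.from_array(data_np, chunks=(-1, 2, 2))
time = np.arange(730_000, 730_000 + n_time)


def func1d(arr, time):
    # Return only the per-time-step array; apply_along_axis needs array output
    return fake_detect(time, arr)[1]["seas"]


# shape/dtype describe func1d's output, so dask skips the length-1 probe call
res = da.apply_along_axis(func1d, 0, data, time=time, shape=(n_time,), dtype=float)
output_array = res.compute()
```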

Solution

  • Rather than using delayed, this seems like a good case for dask.array.

    You can create the dask array by partitioning the numpy array:

    import dask.array
    # keep the full time axis in each chunk (-1) and split space into 10x10 tiles
    data = dask.array.from_array(data, chunks=(-1, 10, 10))
    

    Now you can call mhw.detect on each block with dask.array.map_blocks, using np.apply_along_axis to loop over the pixels within each block:

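    # define a wrapper to rearrange arguments and keep only the 'seas' array,
    # because np.apply_along_axis needs each call to return an array
    def func1d(arr, time):
        return mhw.detect(time, arr)[1]['seas']
    
    def block_func(block, **kwargs):
        # runs on one in-memory NumPy block, looping over its pixels
        return np.apply_along_axis(func1d, 0, block, **kwargs)
    
    # 'seas' has one value per time step, so output blocks match the input blocks
    res = data.map_blocks(block_func, meta=np.array((), dtype=float), time=time)
    res = res.compute()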
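A self-contained version of the map_blocks approach that runs without marineHeatWaves installed. `fake_detect` is a hypothetical stand-in for mhw.detect (assumed to return a 2-tuple whose second element is a dict with a 'seas' array), the sizes are small, and `meta` is an empty float NumPy array so dask does not have to call the function to learn the output type. Each chunk keeps the whole time axis (-1), so every block holds complete per-pixel time series.

```python
import dask.array as da
import numpy as np


# Hypothetical stand-in for mhw.detect
def fake_detect(t, temp):
    return {}, {"seas": np.full(temp.shape, temp.mean())}


n_time, n_lat, n_long = 50, 20, 20
data_np = np.random.randint(0, 30, size=(n_time, n_lat, n_long))
time = np.arange(730_000, 730_000 + n_time)

# Whole time axis per chunk, 10x10 spatial tiles -> 2x2 blocks here
data = da.from_array(data_np, chunks=(-1, 10, 10))


def func1d(arr, time):
    # Keep only the 'seas' array so np.apply_along_axis gets array output
    return fake_detect(time, arr)[1]["seas"]


def block_func(block, time):
    # Operates on one in-memory NumPy block, looping over its pixels serially
    return np.apply_along_axis(func1d, 0, block, time=time)


# Output blocks have the same shape as input blocks, so chunks are reused
res = data.map_blocks(block_func, time=time, meta=np.array((), dtype=float))
output_array = res.compute()
```

The chunk size sets the task granularity: on the question's 100x100 grid, 10x10 tiles give 100 tasks rather than 10,000 per-pixel ones, while each task still loops over its 100 pixels in plain NumPy.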