
Use xr.apply_ufunc on a function that can only take scalars, with multi-dimensional DataArray inputs


Let's say I have some function f that only accepts scalar values:

import numpy as np
import xarray as xr

def f(a, b):
    # Ensure inputs are scalars
    assert(np.isscalar(a) and np.isscalar(b))
    
    result1 = a + b
    result2 = a * b
    return result1, result2

I also have two 2D xarray DataArrays:

da1 = xr.DataArray(np.random.randn(3, 3), dims=('x', 'y'), name='a')
da2 = xr.DataArray(np.random.randn(3, 3), dims=('x', 'y'), name='b')

I want to apply f() element-wise to da1 and da2, i.e. to each pair of values at the same (x, y) position. Essentially, I want to do this:

result1_da = xr.zeros_like(da1)
result2_da = xr.zeros_like(da2)
for xi in da1['x']:
    for yi in da2['y']:
        result1, result2 = f(da1.sel(x=xi, y=yi).item(),
                             da2.sel(x=xi, y=yi).item())
        result1_da.loc[dict(x=xi, y=yi)] = result1
        result2_da.loc[dict(x=xi, y=yi)] = result2

But without looping. I think I should be able to do this using xr.apply_ufunc, but I can't quite get it to work. If I do this:

xr.apply_ufunc(f, da1, da2)

I get an assertion error, because scalars are not being passed in:

assert(np.isscalar(a) and np.isscalar(b) )
AssertionError

I've also tried experimenting with input_core_dims and the other xr.apply_ufunc parameters, but I can't get anything to work.
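A quick way to see why the assertion fails is to pass a small recording function instead of f. The helper `spy` below is hypothetical, only for illustration, and the sketch assumes NumPy-backed (non-dask) DataArrays: without `vectorize=True`, `apply_ufunc` calls the function once with the full underlying NumPy arrays rather than once per element.

```python
import numpy as np
import xarray as xr

da1 = xr.DataArray(np.random.randn(3, 3), dims=("x", "y"), name="a")
da2 = xr.DataArray(np.random.randn(3, 3), dims=("x", "y"), name="b")

# Hypothetical helper: records the type and shape of what apply_ufunc passes in.
calls = []

def spy(a, b):
    calls.append((type(a), np.shape(a), np.shape(b)))
    return a + b  # works on whole arrays, so apply_ufunc can wrap the result

result = xr.apply_ufunc(spy, da1, da2)

# Expect a single entry describing two (3, 3) numpy.ndarray arguments.
print(calls)
```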


Solution

  • As I understand the documentation of xarray.apply_ufunc, it assumes by default that the given function f can operate on whole NumPy arrays, so it passes the underlying arrays of da1 and da2 to f in a single call. If f cannot do that, as is the case with your f, you must set vectorize=True, which wraps f in np.vectorize so that it is called once per pair of elements. However, this alone does not work here, because your f returns a tuple of two numbers, while apply_ufunc by default expects exactly one output per call. One workaround is to wrap f so that each call returns a single number:

    result1 = xr.apply_ufunc(lambda a, b: f(a, b)[0], da1, da2, vectorize=True)
    result2 = xr.apply_ufunc(lambda a, b: f(a, b)[1], da1, da2, vectorize=True)
    

    This gives identical results to your for-loop.

    If f were more computationally expensive than in this example, it could become a problem that the suggested solution calls f twice for each pair of values in da1 and da2. In that case, you would need a way to avoid recomputing intermediate results, and the best approach probably depends on the actual f.
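If calling f twice per element pair is a concern, one option worth trying is to declare two outputs with `output_core_dims=[[], []]`. This is a sketch, not verified against every xarray version: each output has no core dimensions, so with `vectorize=True` f is called once per pair and np.vectorize splits the returned tuples into two arrays, which apply_ufunc returns as a tuple of two DataArrays. Passing `output_dtypes` is optional; the xarray docs say it is used when `vectorize=True`, and giving it lets np.vectorize skip the extra call it otherwise makes on the first elements to infer the output types.

```python
import numpy as np
import xarray as xr

def f(a, b):
    # Ensure inputs are scalars
    assert np.isscalar(a) and np.isscalar(b)
    return a + b, a * b

da1 = xr.DataArray(np.random.randn(3, 3), dims=("x", "y"), name="a")
da2 = xr.DataArray(np.random.randn(3, 3), dims=("x", "y"), name="b")

# output_core_dims=[[], []]: two outputs, each a scalar per element pair.
# output_dtypes is forwarded to np.vectorize as otypes.
result1, result2 = xr.apply_ufunc(
    f, da1, da2,
    vectorize=True,
    output_core_dims=[[], []],
    output_dtypes=[float, float],
)
```

Because f here is just a + b and a * b, the results can be checked directly against da1 + da2 and da1 * da2.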