
Slicing using broadcasting in dask


I want to select values of array x using the indices stored in another array, filt. The output should contain one value from x for each index in filt.

import dask.array as da
import numpy as np

x = np.linspace(0,5,10).reshape(10,1)

# shape is (1,3)
filt = np.array([[2,3,5]])

# get the x values at indices 2, 3 and 5
x[filt]
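To see where the extra dimension comes from, here is a quick NumPy-only check of the shapes involved. The names `filt_2d` and `filt_1d` are introduced here for illustration; they are not from the question. In NumPy, an integer index array replaces the indexed axis with the index array's own shape.

```python
import numpy as np

x = np.linspace(0, 5, 10).reshape(10, 1)

# A 2-D index of shape (1, 3) replaces axis 0 of x, so the result is (1, 3, 1)
filt_2d = np.array([[2, 3, 5]])
print(x[filt_2d].shape)  # (1, 3, 1)

# A 1-D index of shape (3,) gives the (3, 1) result most people expect
filt_1d = np.array([2, 3, 5])
print(x[filt_1d].shape)  # (3, 1)

# Both select the same values; only the shape of the result differs
print(np.array_equal(x[filt_2d].reshape(-1), x[filt_1d].reshape(-1)))  # True
```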

This works in NumPy. How would I get it to work in dask? The dask version below currently fails with an AssertionError.

x = da.linspace(0,5,10).reshape(10,1)
filt = da.array([[2,3,5]])

x[filt]

Solution

  • Your issue is that your filter array has an extra dimension! You want the slice to look like x[[2,3,5]], so set filt to the one-dimensional [2,3,5]. In NumPy, an index array with extra dimensions is allowed and the output gains those dimensions too (here x[filt] has shape (1, 3, 1) rather than (3, 1)), but dask does not support indexing with a multi-dimensional integer array like this.

    import dask.array as da
    
    x = da.linspace(0,5,10).reshape(10,1)
    filt = da.array([2,3,5])
    
    x[filt].compute()  # output: array([[1.11111111], [1.66666667], [2.77777778]])
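If you already have a 2-D index and want the same (1, 3, 1) shape NumPy would give, one option is to flatten the index, select with the 1-D version (which dask does support), and then reshape. This is a minimal sketch assuming the index only applies to axis 0 of x; the names `flat` and `result` are illustrative.

```python
import dask.array as da
import numpy as np

x = da.linspace(0, 5, 10).reshape(10, 1)
filt = da.array([[2, 3, 5]])  # 2-D index, as in the question

# Flatten the index to 1-D, which dask accepts for fancy indexing on axis 0
flat = x[filt.ravel()]  # shape (3, 1)

# Restore the shape NumPy would produce: filt.shape + x.shape[1:] == (1, 3, 1)
result = flat.reshape(filt.shape + x.shape[1:])

print(result.shape)  # (1, 3, 1)
print(result.compute())
```

Both steps are lazy; nothing is computed until `.compute()` is called, and the reshape only restores NumPy's output layout without changing which values are selected.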