
NumPy array.all() on a multidimensional array: a one-step alternative to array.all(axis=1).all(axis=1)?


I have a multidimensional NumPy-like array (I'm using Dask, but this applies to NumPy too, since Dask mimics its API) derived from a stack of 1592 images:

a:

array([[[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       ...,

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]]])

I want to retain images where the masks have False entries and get rid of images that are all True. I can do this with array.all() as:

mask = a.all(axis=1).all(axis=1)
retain = np.where(mask==False,filenames,None)
#write `retain` to a file to be read by another script

where filenames is my list of file paths.
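To see what that np.where call produces, here is a minimal sketch with a tiny mask stack and three made-up file names (both hypothetical stand-ins for the real data). The result keeps one slot per image, with None wherever the image was entirely True, so those placeholders still have to be skipped before writing the file.

```python
import numpy as np

# Hypothetical inputs: 3 tiny 4x4 masks and made-up file names.
filenames = ["img_000.png", "img_001.png", "img_002.png"]
a = np.ones((3, 4, 4), dtype=bool)
a[1, 2:, :] = False  # only img_001 contains False pixels

mask = a.all(axis=1).all(axis=1)                   # True where the whole image is True
retain = np.where(mask == False, filenames, None)  # object array: None where mask is True

# Drop the None placeholders; these lines could then be written one per line to a file.
lines = [str(name) for name in retain if name is not None]
print(lines)  # ['img_001.png']
```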

However, I don't find a.all(axis=1).all(axis=1) very satisfactory. This looks to me like I am running over the array twice, when once should be enough. But am I?

Note: a.all(axis=1) gives:

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True],

       ...,

       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False]])

and a.all(axis=1).all(axis=1) gives:

array([ True, False, False, False,  True,  True, False, False, False,

        ...,

       False, False, False,  True, False, False, False])
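To make the shapes concrete, here is a small sketch with a synthetic stack standing in for the 1592 real masks (the 4×5×6 size and the placement of the False pixels are arbitrary assumptions). Each .all(axis=1) removes one axis: (images, rows, columns) → (images, columns) → (images,).

```python
import numpy as np

# Hypothetical stand-in for the real masks: 4 images of 5 rows x 6 columns.
a = np.ones((4, 5, 6), dtype=bool)
a[1, 3:, :] = False  # image 1: bottom rows are False, like the last image above
a[2, 0, 0] = False   # image 2: a single False pixel

step1 = a.all(axis=1)      # collapse the row axis    -> shape (4, 6)
step2 = step1.all(axis=1)  # collapse the column axis -> shape (4,)

print(a.shape, step1.shape, step2.shape)  # (4, 5, 6) (4, 6) (4,)
print(step2)                              # [ True False False  True]
```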

Can I go from 3-dimensional data to 1-dimensional data more efficiently for this example?


Solution

  • Does

    mask = a.all(axis=(1,2))
    

    do the job? It reduces over both trailing axes in a single call, which should be faster than what you are currently doing. Note that while you do go over data twice, the second pass is much cheaper. You are effectively doing

    b = a.all(axis=1)
    mask = b.all(axis=1)
    

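    so the second pass runs over b, which holds one value per image column rather than one per pixel, and is therefore smaller by a factor equal to the image height.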


    PS: you can simplify your code as follows.

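    mask = a.all(axis=(1, 2))              # with Dask, .compute() first gives an in-memory NumPy mask
    retain = np.asarray(filenames)[~mask]  # a plain list cannot be indexed with a boolean array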
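A quick check of both claims in the answer on a synthetic stack (the 10×64×48 shape and random contents are arbitrary assumptions, and the element counts are a simple model of work, not timings): the single-call and chained forms return the same mask, and the second pass of the chained form reads only 1/height as many elements as the first.

```python
import numpy as np

# Hypothetical mask stack: 10 images of 64 rows x 48 columns, mostly True.
rng = np.random.default_rng(0)
a = rng.random((10, 64, 48)) > 0.0005

chained = a.all(axis=1).all(axis=1)  # two reductions: (10, 64, 48) -> (10, 48) -> (10,)
single = a.all(axis=(1, 2))          # one reduction over rows and columns together
assert np.array_equal(chained, single)

n, h, w = a.shape
first_pass = n * h * w  # elements read by a.all(axis=1)
second_pass = n * w     # elements read by the second .all(axis=1) on the (n, w) result
print(first_pass, second_pass, second_pass / first_pass)  # 30720 480 0.015625
```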