I have a multidimensional NumPy-like array `a` (I'm using Dask, but this applies to NumPy too, since Dask mimics its API), derived from a stack of 1592 images:
array([[[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]],
...,
[[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]])
I want to retain the images whose masks contain `False` entries and get rid of the images that are all `True`. I can do this with `array.all()` as follows:
mask = a.all(axis=1).all(axis=1)
retain = np.where(mask==False,filenames,None)
#write `retain` to a file to be read by another script
where `filenames` is my list of file paths.
However, I don't find `a.all(axis=1).all(axis=1)` very satisfactory. It looks as if I am running over the array twice, when once should be enough. But am I?
Note: `a.all(axis=1)` gives:
array([[ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True],
...,
[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False]])
and `a.all(axis=1).all(axis=1)` gives:
array([ True, False, False, False, True, True, False, False, False,
...,
False, False, False, True, False, False, False])
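To make the shapes concrete, here is a small sketch on a synthetic boolean array standing in for the image masks. The sizes (5 images of 4×3 pixels) and the random data are assumptions for illustration only; the point is how each `.all(axis=1)` call drops one dimension.

```python
import numpy as np

# Hypothetical stand-in for the image masks: 5 images, each 4 rows x 3 columns.
rng = np.random.default_rng(0)
a = rng.random((5, 4, 3)) > 0.1   # mostly True
a[0] = True                       # image 0 is all True
a[1, 2, 1] = False                # image 1 has a single False pixel

step1 = a.all(axis=1)             # collapse the row axis    -> shape (5, 3)
step2 = step1.all(axis=1)         # collapse the column axis -> shape (5,)

print(a.shape, step1.shape, step2.shape)  # (5, 4, 3) (5, 3) (5,)
```

After the first call, axis 1 of the intermediate is the original column axis, which is why the second call also uses `axis=1`.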
Can I go from 3-dimensional data to 1-dimensional data more efficiently in this example?
Does
mask = a.all(axis=(1,2))
do the job? It reduces over both image axes in a single call and is faster than what you are currently doing. Note that although you do go over an array twice, the second array is much smaller. You are effectively doing:
b = a.all(axis=1)
mask = b.all(axis=1)
so the second pass runs over an array that is smaller by a factor of the image height (`a.shape[1]`).
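A minimal check, assuming NumPy and a synthetic array of hypothetical size, that the single call produces the same mask as the chained calls, plus the number of elements each pass of the chained version reads. The timing lines are only a way to measure on your own machine; no particular result is implied.

```python
import numpy as np
import timeit

# Hypothetical size: 200 images of 64 x 48 pixels.
n, h, w = 200, 64, 48
rng = np.random.default_rng(1)
a = rng.random((n, h, w)) > 0.001

chained = a.all(axis=1).all(axis=1)   # two reductions, creates an (n, w) temporary
single = a.all(axis=(1, 2))           # one reduction over both image axes

assert np.array_equal(chained, single)

# Elements read by each pass of the chained version:
first_pass = n * h * w    # the full 3-D array
second_pass = n * w       # the (n, w) intermediate, h times smaller
print(first_pass, second_pass)   # 614400 9600

# Timings depend on the machine and array size.
t_chained = timeit.timeit(lambda: a.all(axis=1).all(axis=1), number=100)
t_single = timeit.timeit(lambda: a.all(axis=(1, 2)), number=100)
print(f"chained: {t_chained:.4f}s  single: {t_single:.4f}s")
```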
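PS: you can simplify your code as follows. Boolean indexing requires `filenames` to be a NumPy array rather than a list, and with Dask you should call `.compute()` on `mask` first.
mask = a.all(axis=(1,2))
retain = np.asarray(filenames)[~mask]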
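A usage sketch with hypothetical file names and a hand-written mask in place of real image data. Unlike the `np.where(..., None)` version, this returns only the retained paths, with no `None` placeholders.

```python
import numpy as np

# Hypothetical file names and per-image masks (True = every pixel is True).
filenames = ["img_000.png", "img_001.png", "img_002.png", "img_003.png"]
mask = np.array([True, False, False, True])

# Boolean indexing needs an array; indexing a plain list this way raises TypeError.
retain = np.asarray(filenames)[~mask]
print(retain.tolist())   # ['img_001.png', 'img_002.png']

# If `mask` came from Dask, materialise it first, e.g. mask = mask.compute().
```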