Tags: python, arrays, numpy, filter, abstract-syntax-tree

np.where with arbitrary number of conditions


Problem

This question: "Numpy where function multiple conditions" asks how to use np.where with two conditions. This answer suggests using the & operator between conditions, which works when there are only a few conditions and they can be typed out by hand. This answer suggests using np.logical_and, which takes only two input arrays.

This thread: Numpy "where" with multiple conditions also discusses multiple conditions for np.where, but there the number of conditions is known in advance.

I am looking for a way to evaluate an np.where expression without knowing the number of conditions in advance.


Reproducible setup

I have a 2D array:

import numpy as np

arr = np.array([[1,2,3,4],
                [4,5,6,7],
                [9,8,7,6],
                [0,1,0,1],
                [9,7,6,5]])

I want to select the rows whose element at index 1 is larger than 5 and whose element at index 2 is larger than 4. To do that, I do:

res = arr[np.where((arr[:,1]>5) & (arr[:,2]>4))]

res is then:

array([[9, 8, 7, 6],
       [9, 7, 6, 5]])

as expected.

But what if I have these conditions as lists? The above example would be:

cols = [1,2] # arbitrary length list
tholds = [5,4] # arbitrary length list

The length of these two lists is not known in advance, but they always have the same length.

How can I get res using the cols and tholds lists?


What I have tried

I built the condition as a string, intending to evaluate it with ast.literal_eval:

filterstring = "&".join([f"(arr[:,{col}]>{th})" for col, th in zip(cols,tholds)])

which evaluates to (arr[:,1]>5)&(arr[:,2]>4), i.e. what we had above within np.where() when the conditions were typed out manually.

However, ast.literal_eval(f"np.where({filterstring})") gives an error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-269-1aaff20de82f> in <module>()
----> 1 ast.literal_eval(f"np.where({filterstring})")

3 frames
/usr/lib/python3.7/ast.py in _convert_num(node)
     53         elif isinstance(node, Num):
     54             return node.n
---> 55         raise ValueError('malformed node or string: ' + repr(node))
     56     def _convert_signed_num(node):
     57         if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):

ValueError: malformed node or string: <_ast.Call object at 0x7f41daa21f10>

So this did not work. This answer to the question ast.literal_eval() malformed node or string while converting a string with list of array()s confirms that this is not the right approach.
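For context, here is a minimal sketch of why the call fails: ast.literal_eval only accepts Python literals (numbers, strings, tuples, lists, dicts and the like) and rejects anything that has to be executed, such as a function call. The strings below are made up purely for illustration and are not part of the original attempt.

```python
import ast

# A string containing only literals is parsed into the matching objects.
parsed = ast.literal_eval("[(1, 5), (2, 4)]")

# A function call is not a literal, so literal_eval refuses to evaluate it.
# This is the same kind of failure as with "np.where(...)" above.
try:
    ast.literal_eval("len([1, 2, 3])")
    call_rejected = False
except ValueError:
    call_rejected = True
```

Here parsed is a list of two tuples, while call_rejected ends up True because the second string contains a call node.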


EDIT:

The suggestion to call np.where several times in succession works fine for this particular example, but it is not really what I am looking for. I would like to call np.where once, rather than calling it multiple times with a single condition each.


Solution

  • You can do this by selecting the relevant columns of your array with a list of indexes (fancy indexing, which returns a copy rather than a view), comparing them against the thresholds with a broadcast comparison, and reducing the result along each row with np.all:

    >>> arr[np.where(np.all(arr[:,cols] > tholds, axis=1))]
    array([[9, 8, 7, 6],
           [9, 7, 6, 5]])
    

    As your first link indicates (and as mentioned in the Note at the top of the documentation for np.where), there is actually no need for np.where in this case; it only slows things down. You can index a NumPy array directly with a boolean array, so there is no need to convert the boolean array to an array of indexes first. Since np.all, like the & operator, returns a NumPy array of boolean values, there is also no need for np.asarray or np.nonzero (as suggested in the aforementioned note):

    >>> arr[np.all(arr[:,cols] > tholds, axis=1)]
    array([[9, 8, 7, 6],
           [9, 7, 6, 5]])
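
The one-liner above can be unpacked into a self-contained sketch that shows each intermediate step using the arr, cols and tholds from the question. The intermediate names (selected, compared, mask, conditions, mask_alt) are introduced here for illustration only. The last two lines show a possible alternative, np.logical_and.reduce over a list of per-condition masks, which may be handy if the conditions do not all use the same comparison operator.

```python
import numpy as np

arr = np.array([[1, 2, 3, 4],
                [4, 5, 6, 7],
                [9, 8, 7, 6],
                [0, 1, 0, 1],
                [9, 7, 6, 5]])
cols = [1, 2]    # arbitrary-length list of column indexes
tholds = [5, 4]  # thresholds, same length as cols

# Step 1: pick the columns of interest; the result has shape (5, 2).
selected = arr[:, cols]

# Step 2: broadcast comparison, each column against its own threshold.
compared = selected > tholds

# Step 3: a row qualifies only if all of its comparisons are True.
mask = np.all(compared, axis=1)

# Step 4: index with the boolean mask directly, no np.where needed.
res = arr[mask]

# Alternative: build one boolean array per condition and AND them together.
conditions = [arr[:, c] > t for c, t in zip(cols, tholds)]
mask_alt = np.logical_and.reduce(conditions)
```

Both mask and mask_alt select rows 2 and 4 of arr, so arr[mask_alt] gives the same res as above.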