arrays numpy for-loop matrix where-clause

Using a loop (e.g., for) in the condition argument of numpy.where


I have a matrix "x" and an array "index". I just want to find the position (row number) of the "index" array in the matrix "x". Example:

import numpy as np

x = np.array([[0, 0],
              [1, 0],
              [2, 0],
              [0, 1],
              [1, 1],
              [2, 1],
              [0, 2],
              [1, 2],
              [2, 2],
              [0, 3],
              [1, 3],
              [2, 3]])

index = [2,1]

Here, the code np.where((x[:,0]==index[0]) & (x[:,1]==index[1]))[0] works.

But if the matrix "x" has N columns (instead of 2), I need a loop inside the np.where arguments. I tried this: np.where((for b in range(2):(x[:,b]==index[b]) & (x[:,b]==index[b])))[0]

This raises an "invalid syntax" error. Can you please help me with this? Thanks in advance.


Solution

  • np.where is only as good as its argument, which is evaluated in full before being passed to the function:

    In [292]: np.where((x[:,0]==index[0]) & (x[:,1]==index[1]))[0]
    Out[292]: array([5], dtype=int64)
    

    The condition is a boolean array:

    In [293]: (x[:,0]==index[0]) & (x[:,1]==index[1])
    Out[293]: 
    array([False, False, False, False, False,  True, False, False, False,
           False, False, False])
    

    Looks like you tried to create a for loop:

    for b in range(2):
        (x[:,b]==index[b]) & (x[:,b]==index[b])
    

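    Using that as an argument is not valid Python: a for loop is a statement, not an expression, so it has no value to pass. You could write a function that builds the list of per-column tests: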

    def foo(x, index):
        # Collect one boolean array per column: x[:, b] == index[b]
        res = []
        for b in range(len(index)):
            res.append(x[:, b] == index[b])
        return res
    

    But a simpler syntax is a list comprehension:

    In [294]: [x[:,i]==index[i] for i in range(2)]
    Out[294]: 
    [array([False, False,  True, False, False,  True, False, False,  True,
            False, False,  True]),
     array([False, False, False,  True,  True,  True, False, False, False,
            False, False, False])]
    

    and the arrays can be combined with np.all:

    In [295]: np.all([x[:,i]==index[i] for i in range(2)], axis=0)
    Out[295]: 
    array([False, False, False, False, False,  True, False, False, False,
           False, False, False])
    

    But as others show, you don't need to iterate. Let the (n,2) array x broadcast against the (2,) index:

    In [296]: x==index
    Out[296]: 
    array([[False, False],
           [False, False],
           [ True, False],
           [False,  True],
           [False,  True],
           [ True,  True],
           [False, False],
           [False, False],
           [ True, False],
           [False, False],
           [False, False],
           [ True, False]])
    
    In [297]: (x==index).all(axis=1)
    Out[297]: 
    array([False, False, False, False, False,  True, False, False, False,
           False, False, False])
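
To get back to the row numbers the question asks for, the broadcast mask can be passed to np.where. The following is a minimal sketch of that, assuming x has N columns and index has length N. The helper name find_rows and the 3-column array x3 are made up for illustration and are not part of the original question:

```python
import numpy as np

def find_rows(x, index):
    """Return the row numbers of x whose entries all equal index."""
    x = np.asarray(x)
    index = np.asarray(index)
    # (n, N) == (N,) broadcasts to an (n, N) boolean array;
    # all(axis=1) keeps only the rows where every column matched.
    mask = (x == index).all(axis=1)
    return np.where(mask)[0]

# Same 12x2 matrix as in the question, built row by row
x = np.array([[i, j] for j in range(4) for i in range(3)])
rows_2col = find_rows(x, [2, 1])

# Hypothetical 3-column example to show that N is not fixed at 2
x3 = np.array([[0, 0, 0], [2, 1, 7], [2, 1, 8], [2, 1, 7]])
rows_3col = find_rows(x3, [2, 1, 7])

print(rows_2col, rows_3col)  # prints: [5] [1 3]
```

Because the comparison relies on broadcasting, no loop over columns is written out, and an index that matches no row simply yields an empty array.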