Tags: python, arrays, numpy, mask

How can a NumPy array of booleans be used to remove/filter rows of another NumPy array?


I have one NumPy array like this:

array([[ True],
       [ True],
       [ True],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False]], dtype=bool)

I want to use this array to filter the rows of another array like this:

array([[-0.45556594,  0.46623859],
       [-1.80758847, -0.08109728],
       [-0.9792373 , -0.15958186],
       [ 4.58101272, -0.02224513],
       [-1.64387422, -0.03813   ],
       [-1.8175146 , -0.07419429],
       [-1.15527867, -0.1074057 ],
       [-1.48261467, -0.00875623],
       [ 2.23701103,  0.67834847],
       [ 1.45440669, -0.62921477],
       [-1.13694557,  0.07002631],
       [ 1.0645533 ,  0.21917462],
       [-0.03102173,  0.18059074],
       [-1.16885461, -0.06968157],
       [-0.51789417, -0.05855351],
       [ 4.23881128, -0.30072904],
       [-1.37940507, -0.06478938]])

Applying the filter would result in the following array, with just the first three rows:

array([[-0.45556594,  0.46623859],
       [-1.80758847, -0.08109728],
       [-0.9792373 , -0.15958186]])

How can this be done? When I attempt to do something like B[A], where A is the filter array and B is the other one, I get only the first column.
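To see why B[A] does not select whole rows, compare the shapes of the two arrays. The sketch below is a minimal illustration using only the first four rows of A and B from above. One assumption to note: as far as I know, recent NumPy versions raise an IndexError for this shape mismatch instead of returning the first column, so what you see depends on your NumPy version.

```python
import numpy as np

# Illustrative stand-ins: the first four rows of A and B from the question.
A = np.array([[True], [True], [True], [False]])
B = np.array([[-0.45556594,  0.46623859],
              [-1.80758847, -0.08109728],
              [-0.9792373 , -0.15958186],
              [ 4.58101272, -0.02224513]])

print(A.shape)  # (4, 1): a 2-D mask with a single column
print(B.shape)  # (4, 2)

# A boolean index must match the shape of the leading axes it covers.
# Here the mask's second axis (length 1) disagrees with B's (length 2).
error = None
try:
    B[A]
except IndexError as exc:
    error = exc
    print("IndexError:", exc)

# Flattening the mask to shape (4,) makes it select whole rows instead.
rows = B[A.ravel()]
print(rows.shape)  # (3, 2)
```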


Solution

  • You are trying to select entire rows, so you need a one-dimensional boolean array to select with. As mentioned in the comments, you can use numpy.ravel() to flatten your boolean array and apply it to b with:

    b[a.ravel()]
    

    You can also explicitly select the first column of a and apply it to b with:

    b[a[:, 0]]
    

    Test Code:

    import numpy as np

    a = np.array(
        [[ True],
         [ True],
         [ True],
         [False],
         [False],
         [False]], dtype=bool)
    
    b = np.array(
        [[-0.45556594,  0.46623859],
         [-1.80758847, -0.08109728],
         [-0.9792373 , -0.15958186],
         [ 4.58101272, -0.02224513],
         [-1.64387422, -0.03813   ],
         [-1.37940507, -0.06478938]])
    
    print(b[a.ravel()])
    print(b[a[:, 0]])
    

    Results:

    [[-0.45556594  0.46623859]
     [-1.80758847 -0.08109728]
     [-0.9792373  -0.15958186]]
    
    [[-0.45556594  0.46623859]
     [-1.80758847 -0.08109728]
     [-0.9792373  -0.15958186]]
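Since the question also asks about removing rows, here is a short sketch showing a few equivalent ways to get a 1-D mask, and how ~ inverts the mask to drop the selected rows instead of keeping them. The data in b is a placeholder built with np.arange, not the arrays from the question.

```python
import numpy as np

# Same 6-row mask as in the test code; b is placeholder data for illustration.
a = np.array([[True], [True], [True], [False], [False], [False]])
b = np.arange(12, dtype=float).reshape(6, 2)

# Equivalent ways to turn the (6, 1) mask into a 1-D mask of shape (6,).
flat_masks = [a.ravel(), a[:, 0], a.reshape(-1), np.squeeze(a, axis=1)]
assert all(m.shape == (6,) for m in flat_masks)

keep = a.ravel()
kept = b[keep]       # rows where the mask is True
removed = b[~keep]   # rows where the mask is False, i.e. the True rows are dropped

print(kept)
print(removed)
```

Any of the flattened masks works here: a.ravel() and a[:, 0] are the two used in the answer, while a.reshape(-1) and np.squeeze(a, axis=1) are equivalent alternatives.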