python, arrays, numpy, matrix

Numpy select rows based on condition


I want to remove rows from a two-dimensional NumPy array using a condition on the values of the first column.

I can do this in plain Python with two loops, but I would like to do it more efficiently with NumPy, e.g. with numpy.where.

I have been trying various things with numpy.where and numpy.delete, but I struggle to apply a condition to the first column only.

Here is an example where I only want to keep the rows where the first value of each row is 6.

Input:

[[0,4],
 [0,5],
 [3,5],
 [6,8],
 [9,1],
 [6,1]]

Output:

[[6,8],
 [6,1]]

Solution

  • Use a boolean mask:

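    # z is the 2-D input array from the question
    mask = (z[:, 0] == 6)  # True for rows whose first value is 6
    z[mask, :]             # keep only those rows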
    

    This is simpler than going through np.where, and usually faster, because the boolean mask is used for indexing directly instead of first being converted to an array of row indices.

    One-liner:

    z[z[:, 0] == 6, :]
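
For reference, here is a minimal, self-contained sketch of the mask approach applied to the input from the question. The names kept, kept_where and kept_delete are only illustrative, and the np.where and np.delete variants are included for comparison with what the question was attempting, not as recommendations.

```python
import numpy as np

# Example input from the question
z = np.array([[0, 4],
              [0, 5],
              [3, 5],
              [6, 8],
              [9, 1],
              [6, 1]])

# Boolean mask: one True/False entry per row, computed from the first column
mask = z[:, 0] == 6
kept = z[mask]  # same result as z[mask, :]

# Index-based routes that give the same rows, shown for comparison only
kept_where = z[np.where(mask)[0]]                        # indices of rows to keep
kept_delete = np.delete(z, np.where(~mask)[0], axis=0)   # indices of rows to drop

print(kept)
```

Note that boolean indexing returns a new array (a copy), so z itself is left unchanged; assign the result back to z if the filtered array should replace the original.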