Tags: python, arrays, numpy, masking

How to mask rows of a 2D numpy matrix by values in 1D list?


I have a 2D numpy array that looks like this:

a = np.array([[0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4]])

And a 1D list that looks like this:

b = [4, 3, 2, 3, 4]

I'd like to mask my 2D array (a) according to which values in a given row are less than the corresponding value in my 1D list (b). For example, row a[0] would be masked according to which values in that row are less than the value at b[0]; the same with row a[1] and the value at b[1], and so on...

What I hope to get is a 2D array of booleans:

mask_bools = [[True, True, True, True, False],
              [True, True, True, False, False],
              [True, True, False, False, False],
              [True, True, True, False, False],
              [True, True, True, True, False]]

I have a silly way to achieve this with a loop:

mask_bools = []
for i in range(len(b)):
    mask_bools.append(np.ma.masked_less(a[i], b[i]).mask)
mask_bools = np.array(mask_bools)
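The loop only uses np.ma.masked_less for its .mask, and that mask is True exactly where a[i] < b[i]. Here is a minimal sketch that builds each row's mask with a plain < comparison and checks it against the masked_less version. It assumes the same a and b as above; loop_mask and cmp_mask are illustrative names.

```python
import numpy as np

a = np.array([[0, 1, 2, 3, 4]] * 5)  # same 5x5 array as in the question
b = [4, 3, 2, 3, 4]

# Loop from the question: keep only the mask that masked_less produces per row
loop_mask = np.array([np.ma.masked_less(a[i], b[i]).mask for i in range(len(b))])

# The same rows computed with a plain elementwise comparison, no masked arrays
cmp_mask = np.array([a[i] < b[i] for i in range(len(b))])

print(np.array_equal(loop_mask, cmp_mask))  # True
```

This still loops in Python, so it only shows what the mask means. It does not remove the per-row overhead.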

But I feel like there must be a better/faster way to do this that takes better advantage of numpy functionality. Any ideas? Thanks!


Solution

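  • Try a broadcast less-than comparison. b[:, None] turns b into a column of shape (5, 1), so each row of a is compared against its own element of b. Indexing with [:, None] only works on a NumPy array, so convert a plain list first (e.g. with np.asarray):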

    a < np.asarray(b)[:, None]
    
    [[ True  True  True  True False]
     [ True  True  True False False]
     [ True  True False False False]
     [ True  True  True False False]
     [ True  True  True  True False]]
    
    import numpy as np
    
    a = np.array([[0, 1, 2, 3, 4],
                  [0, 1, 2, 3, 4],
                  [0, 1, 2, 3, 4],
                  [0, 1, 2, 3, 4],
                  [0, 1, 2, 3, 4]])
    
    b = np.array([4, 3, 2, 3, 4])
    
    c = a < b[:, None]
    
    # Test equality with expected output
    mask_bools = np.array([[True, True, True, True, False],
                           [True, True, True, False, False],
                           [True, True, False, False, False],
                           [True, True, True, False, False],
                           [True, True, True, True, False]])
    
    print(np.array_equal(c, mask_bools))  # True
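The comparison works because b[:, None] adds a length-1 axis, giving b the shape (5, 1). NumPy's broadcasting rules then stretch that single column across the 5 columns of a. The sketch below prints the shapes to make this visible. Since the question started from np.ma, it also wraps the result in a masked array with np.ma.masked_array(a, mask=...). The names col and masked are illustrative only.

```python
import numpy as np

a = np.tile(np.arange(5), (5, 1))  # 5x5 array, every row is [0, 1, 2, 3, 4]
b = np.array([4, 3, 2, 3, 4])

col = b[:, None]  # equivalent to b.reshape(-1, 1)
print(a.shape, b.shape, col.shape)  # (5, 5) (5,) (5, 1)

# (5, 5) < (5, 1): the single column of b is broadcast across all 5 columns
mask_bools = a < col

# Optional: keep the data and hide the entries that are below b[i] in each row
masked = np.ma.masked_array(a, mask=mask_bools)
print(masked.sum(axis=1).tolist())  # [4, 7, 9, 7, 4], sums of unmasked values only
```

Without the extra axis, a < b would compare along the last axis instead. Each column j would be tested against b[j], which is a different (and here wrong) mask.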