
NumPy where using multi-dimensional arrays


I have an array of N 2x2 arrays (shape (N, 2, 2)) and an array of N Booleans marking which of those 2x2 arrays are good. I want to replace the bad 2x2 arrays with np.zeros((2,2)).

I tried np.where but it failed. Here's an example with N=5:

>>> a = np.arange(20).reshape((5,2,2))
>>> a
array([[[ 0,  1],
        [ 2,  3]],

       [[ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15]],

       [[16, 17],
        [18, 19]]])
>>> good = np.array([ x%4 != 3 for x in range(5) ])
>>> good
array([ True,  True,  True, False,  True])
>>> np.where(good, a, np.zeros((2,2)))
ValueError: operands could not be broadcast together with shapes (5,) (5,2,2) (2,2)
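The shapes in the error message explain the failure. Broadcasting aligns arrays from their trailing axes, so the mask of shape (5,) is matched against the last axis of a, which has length 2, and the shapes are incompatible. A minimal sketch of the same np.where call, assuming the mask is given two trailing axes so it has shape (5, 1, 1) and broadcasts against both (5, 2, 2) and (2, 2):

```python
import numpy as np

a = np.arange(20).reshape((5, 2, 2))
good = np.array([x % 4 != 3 for x in range(5)])

# good[:, None, None] has shape (5, 1, 1): one True/False per 2x2 block,
# broadcast over that block's 2x2 elements.
# The zeros use a's dtype so the result stays an integer array.
result = np.where(good[:, None, None], a, np.zeros((2, 2), dtype=a.dtype))
print(result)
```

np.where returns a new array and leaves a unchanged. A scalar 0 in place of the zeros array would broadcast just as well.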

I was expecting:

>>> a
array([[[ 0,  1],
        [ 2,  3]],

       [[ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11]],

       [[ 0,  0],
        [ 0,  0]],

       [[16, 17],
        [18, 19]]])

Is there a NumPy way of doing this? I finally resorted to a conditional list comprehension.
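The question does not show the list comprehension, so the following is only a guess at what such a workaround might look like. It is included for comparison with the vectorized answers. It walks the blocks in Python, substitutes a zero block where the mask is False, and restacks the result:

```python
import numpy as np

a = np.arange(20).reshape((5, 2, 2))
good = np.array([x % 4 != 3 for x in range(5)])

# Hypothetical reconstruction of the list-comprehension workaround:
# keep each 2x2 block if it is good, otherwise use a 2x2 block of zeros.
result = np.array([block if ok else np.zeros((2, 2), dtype=a.dtype)
                   for block, ok in zip(a, good)])
print(result)
```

This produces the expected array, but it runs a Python-level loop over all N blocks, which is what the indexing answer avoids.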


Solution

  • It's easier to use Boolean indexing for this:

    a[~good] = 0.0
    

    See documentation and examples here:
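A short sketch of what the one-liner does on the example from the question. Indexing the first axis with ~good selects the whole 2x2 blocks where the mask is False, and the assigned scalar is broadcast over every element of those blocks. Because the assignment modifies the array in place, this sketch works on a copy so the original stays available for comparison:

```python
import numpy as np

a = np.arange(20).reshape((5, 2, 2))
good = np.array([x % 4 != 3 for x in range(5)])

b = a.copy()       # keep the original a untouched for comparison
b[~good] = 0       # zero every 2x2 block whose mask entry is False
print(b)
```

Since a is an integer array, assigning 0.0 as in the answer also works. The value is cast to the array's integer dtype.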