Use numpy to mask a column containing only zeros (3D array)


I need to achieve basically the solution in this post, but for a higher-dimensional array. I have an array of shape (100, 24, 29), with dimensions corresponding to (timepoint x W x H), and I want to delete any columns (along the third dimension) that contain only zeros. I have tried to adapt the answer from the linked post in this toy example:

b = np.random.rand(3,3,3)
b[:,:,0] = 0

where b looks like:

array([[[0.        , 0.93285577, 0.25492488],
        [0.        , 0.30008854, 0.04393785],
        [0.        , 0.54639525, 0.91724947]],

       [[0.        , 0.15975869, 0.67710479],
        [0.        , 0.76967775, 0.14067868],
        [0.        , 0.75224997, 0.29507396]],

       [[0.        , 0.0559644 , 0.20334715],
        [0.        , 0.04229135, 0.776371  ],
        [0.        , 0.18207046, 0.80668586]]])

Here's my mask to find the zero-columns:

mask = (b == 0).all(axis=2)  # shape (3, 3)

However, the mask does not find any such columns:

array([[False, False, False],
       [False, False, False],
       [False, False, False]])

and applying the mask with b[~mask] gives the following:

array([[0.        , 0.92218547, 0.03236384],
       [0.        , 0.08624437, 0.26430342],
       [0.        , 0.05117766, 0.47295541],
       [0.        , 0.60342235, 0.25399603],
       [0.        , 0.24224256, 0.09755698],
       [0.        , 0.62463761, 0.67945886]])

Why have I failed to mask the zero columns and how can I apply the mask to a 3D array?


Solution

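  • The problem is that np.all is applied along axis 2. That reduces over the last axis, so it tests whether each slice b[i, j, :] is all zero and returns a mask of shape (3, 3); none of those slices is all zero, so every entry is False. To find the all-zero columns, check instead whether each slice b[:, :, k] is all zero by reducing over axes 0 and 1, which gives a mask of shape (3,) that can index the third axis. Here is the corrected code: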

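    import numpy as np
    
    b = np.random.rand(3, 3, 3)
    b[:, :, 0] = 0  # make the first column all zeros
    
    # Reduce over axes 0 and 1: one boolean per column k, shape (3,)
    mask = (b == 0).all(axis=(0, 1))
    
    # Keep only the columns that are not all zero; result has shape (3, 3, 2)
    print(b[:, :, ~mask])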
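
The same idea should carry over to the (100, 24, 29) array from the question. The sketch below uses made-up data as a stand-in for the real array: the names a, zero_cols, row_mask, col_mask, trimmed and keep, and the choice of which columns to zero out, are assumptions for illustration only. It prints the shape each reduction produces, which shows why the axis=2 mask cannot select columns, and then checks an equivalent formulation using any:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the real data: (timepoint, W, H) = (100, 24, 29).
# Adding 0.1 keeps every random value strictly positive (no accidental zeros).
a = rng.random((100, 24, 29)) + 0.1
zero_cols = [3, 17]          # illustrative choice of all-zero columns
a[:, :, zero_cols] = 0

# Reducing over axis 2 tests each a[i, j, :] and leaves one flag per (i, j).
row_mask = (a == 0).all(axis=2)
print(row_mask.shape)        # (100, 24)

# Reducing over axes 0 and 1 tests each a[:, :, k] and leaves one flag per k.
col_mask = (a == 0).all(axis=(0, 1))
print(col_mask.shape)        # (29,)

# Keep only the columns that contain at least one nonzero value.
trimmed = a[:, :, ~col_mask]
print(trimmed.shape)         # (100, 24, 27)

# Equivalent formulation: keep the columns where any entry is nonzero.
keep = a.any(axis=(0, 1))
assert np.array_equal(keep, ~col_mask)
```

Note that a 2D boolean mask such as row_mask indexes the first two axes together, so b[~mask] in the question selects whole b[i, j, :] rows and flattens them into a 2D result instead of dropping columns.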