
Filtering a numpy.dstack


I have an array built with np.dstack, like this:

import numpy as np
a = np.array((1,2,6))
b = np.array((2,3,4))
c = np.array((8,3,0))
stack = np.dstack((a,b,c))
print(stack)
# [[[1 2 8]
#   [2 3 3]
#   [6 4 0]]]
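The filtering below depends on the shape of stack. np.dstack turns each 1-D input of shape (3,) into shape (1, 3, 1) and then joins them along the third axis, so stack is three-dimensional with shape (1, 3, 3), and each row holds one element from each of a, b and c. A quick check, assuming the same a, b and c as above:

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))

# dstack promotes each (3,) array to (1, 3, 1) before joining on axis 2
print(stack.shape)  # (1, 3, 3)

# The single "page" stack[0] is the same as stacking a, b, c as columns
print(np.array_equal(stack[0], np.column_stack((a, b, c))))  # True

# Index 2 on the last axis is the column that came from c
print(stack[..., 2])  # [[8 3 0]]
```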

I want to filter out the rows whose element at index 2 is less than 1.

Something like this:

new_list = []

for i in stack:
    for d in i[:,2]:
        if d>=1:
            new_list.append(d)
print(new_list) # [8,3]

This only collects the element at index 2, but I would like to keep the whole row, like this:

# [[[1 2 8]
#   [2 3 3]]]

And if I use new_list.append(i) instead, the result is not what I want either.
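With for i in stack, i is the whole (3, 3) page stack[0], not a single row, so appending it keeps every row. A minimal loop-based sketch that iterates one level deeper and keeps whole rows, assuming the same stack as above (the names page, row, kept_rows and result are only illustrative):

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))

kept_rows = []
for page in stack:          # page has shape (3, 3)
    for row in page:        # row has shape (3,)
        if row[2] >= 1:     # test the element at index 2
            kept_rows.append(row)

# Wrap in an extra list to get the (1, 2, 3) layout shown in the question
result = np.array([kept_rows])
print(result)
# [[[1 2 8]
#   [2 3 3]]]
```

This collects rows from every page into a single page, which matches the question because stack has only one page.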


Solution

  • You don't need a loop; you can do it with boolean indexing

    print(stack[stack[..., 2] >= 1])  # stack[..., 2] is index 2 of the last axis
    

    Output

    [[1 2 8]
     [2 3 3]]
    

    If you need it as

    [[[1 2 8]]
     [[2 3 3]]]
    

    you can reshape the result

    filtered = stack[stack[..., 2] >= 1]
    shape = filtered.shape
    print(filtered.reshape((shape[0], 1, shape[1])))
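The boolean-indexing approach can be checked step by step. The sketch below, assuming the same stack as in the question, builds the mask from the last axis with ..., applies it, and shows two ways to add an axis back: [:, np.newaxis, :] gives the (2, 1, 3) layout above, while [np.newaxis] gives the (1, 2, 3) layout asked for in the question (np.newaxis is an alias for None).

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))       # shape (1, 3, 3)

# Mask over the first two axes: True where the element at index 2 is >= 1
mask = stack[..., 2] >= 1
print(mask)                         # [[ True  True False]]

# Indexing with a (1, 3) mask collapses those two axes into one: shape (2, 3)
filtered = stack[mask]
print(filtered.shape)               # (2, 3)

# One row per block, shape (2, 1, 3)
per_row = filtered[:, np.newaxis, :]
print(per_row.shape)                # (2, 1, 3)

# All kept rows in one block, shape (1, 2, 3), as in the question
one_block = filtered[np.newaxis]
print(one_block.shape)              # (1, 2, 3)
```

Because a mask could keep a different number of rows in each page, boolean indexing always returns a flattened 2-D result; adding an axis afterwards only changes the layout and does not restore the original page boundaries.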