
Removing certain rows from a tensor in TensorFlow without using tf.RaggedTensor


Given the tensor:

   [[[ 0.,  0.],
    [ 1.,  1.],
    [-1., -1.]],

   [[-1., -1.],
    [ 4.,  4.],
    [ 5.,  5.]]]

I want to remove every row equal to [-1, -1] and get:

   [[[ 0.,  0.],
    [ 1.,  1.]],

   [[ 4.,  4.],
    [ 5.,  5.]]]

How can I get the above without using the ragged-tensor feature of TensorFlow?


Solution

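  • Build a mask that is True for whole rows equal to [-1, -1] (tf.reduce_all over the last axis), drop those rows with tf.boolean_mask, then reshape back to three dimensions. Comparing x with [-1, -1] element-wise alone gives a mask of shape (2, 3, 2) that selects individual values rather than rows. The final reshape assumes every batch element loses the same number of rows; otherwise it either fails or silently moves rows into the wrong batch element.

    import tensorflow as tf

    x = tf.constant(
          [[[ 0.,  0.],
            [ 1.,  1.],
            [-1., -1.]],

           [[-1., -1.],
            [ 4.,  4.],
            [ 5.,  5.]]])

    # True where the entire row (last axis) equals [-1, -1]; shape (2, 3)
    is_removed = tf.reduce_all(tf.equal(x, tf.constant([-1., -1.])), axis=-1)

    # A 2-D mask over a 3-D tensor keeps whole rows; result shape (4, 2)
    result = tf.boolean_mask(x, tf.logical_not(is_removed))

    # Restore the batch dimension; shape (2, 2, 2)
    shape = tf.shape(x)
    result = tf.reshape(result, (shape[0], -1, shape[2]))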
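If the batch elements can lose different numbers of rows, the reshape above no longer works. A minimal sketch of one dense (non-ragged) alternative, assuming padding the shorter rows is acceptable: a stable tf.argsort moves the kept rows to the front of each batch element in their original order, the result is trimmed to the longest kept count, and leftover slots are filled with a pad value. The helper name drop_rows_padded and its pad_value parameter are hypothetical, not a TensorFlow API.

```python
import tensorflow as tf


def drop_rows_padded(x, bad_row, pad_value=0.0):
    """Hypothetical helper: drop rows equal to bad_row from each batch element.

    Returns (padded, lengths): padded has shape (B, max_kept, D) with the kept
    rows left-packed in original order, and lengths holds the kept count per
    batch element.
    """
    # True for rows to keep; shape (B, N)
    keep = tf.logical_not(tf.reduce_all(tf.equal(x, bad_row), axis=-1))
    lengths = tf.reduce_sum(tf.cast(keep, tf.int32), axis=-1)

    # Sort key 0 = keep, 1 = drop; the stable sort preserves original row order
    order = tf.argsort(tf.cast(tf.logical_not(keep), tf.int32), axis=-1, stable=True)
    packed = tf.gather(x, order, batch_dims=1)           # (B, N, D)
    packed_keep = tf.gather(keep, order, batch_dims=1)   # (B, N)

    # Trim to the largest number of kept rows in any batch element
    max_kept = tf.reduce_max(lengths)
    packed = packed[:, :max_kept]
    packed_keep = packed_keep[:, :max_kept]

    # Overwrite the slots that held dropped rows with the pad value
    pad = tf.fill(tf.shape(packed), tf.cast(pad_value, x.dtype))
    return tf.where(packed_keep[..., tf.newaxis], packed, pad), lengths
```

Usage: with the tensor from the question, `drop_rows_padded(x, tf.constant([-1., -1.]))` returns the same (2, 2, 2) result as the reshape approach, plus lengths `[2, 2]`. The returned lengths can be turned back into a mask with `tf.sequence_mask` if later code needs to ignore the padding.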