Search code examples
pythontensorflowkeras

How to mask row in Tensorflow without for loop


I want to create a custom Layer for a Tensorflow model but the logic I have uses a for loop, which Tensorflow doesn't like. How can I modify my code to remove the for loop but still achieve the same result?

class CustomMask(tf.keras.layers.Layer):
    def call(self, inputs):
        mask = tf.where(inputs[:, 0] < 0.5, 1, 0)
        for i,m in enumerate(mask):
            if m:
                inputs = inputs[i, 1:].assign(tf.zeros(4, dtype=tf.float32))
            else:
                first = tf.where(inputs[:, 1] >= 0.5, 0, 1)
                assign = tf.multiply(tf.cast(first, tf.float32), inputs[:, 2])
                inputs = inputs[:, 2].assign(assign)

                third = tf.where(inputs[:, 1] >= 0.5, 1, 0)
                assign = tf.multiply(tf.cast(third, tf.float32), inputs[:, 1])
                inputs = inputs[:, 1].assign(assign)

        return inputs

Example input Tensor:

<tf.Variable 'Variable:0' shape=(3, 5) dtype=float32, numpy=
array([[0.8, 0.7, 0.2, 0.6, 0.9],
       [0.8, 0.4, 0.8, 0.3, 0.7],
       [0.3, 0.2, 0.4, 0.3, 0.8]], dtype=float32)>

Corresponding output:

<tf.Variable 'UnreadVariable' shape=(3, 5) dtype=float32, numpy=
array([[0.8, 0.7, 0. , 0.6, 0.9],
       [0.8, 0. , 0.8, 0.3, 0.7],
       [0.3, 0. , 0. , 0. , 0. ]], dtype=float32)>

EDIT: The layer should take an array of shape (batch_size, 5) and if the first value of a row is less than 0.5, set the rest of the row values to 0, otherwise if the 2nd element is above 0.5, set the 3rd element to 0 and if the 3rd element is greater than 0.5, set the 2nd element to 0


Solution

  • Without using any foor loop, ask in comments if it doesn't solve your issue (tested in colab)

    import tensorflow as tf
    
    mask1 = tf.convert_to_tensor([0.0,1.0,1.0,1.0,1.0])
    mask2 = tf.convert_to_tensor([0.0,0.0,1.0,0.0,0.0])
    mask3 = tf.convert_to_tensor([0.0,1.0,0.0,0.0,0.0])
    
    def masking(x):
      mask = tf.ones(x.shape, tf.float32) 
      cond1 = tf.cast(x[0] < 0.5, tf.float32)
      x = tf.multiply(x, tf.subtract(mask, tf.multiply(mask1, cond1)))
      cond2 = tf.cast(x[1] > 0.5, tf.float32)
      x = tf.multiply(x, tf.subtract(mask, tf.multiply(mask2, cond2)))
      cond3 = tf.cast(x[2] > 0.5, tf.float32)
      x = tf.multiply(x, tf.subtract(mask, tf.multiply(mask3, cond3)))
      return x
    
    inputs = tf.convert_to_tensor([[0.8, 0.7, 0.2, 0.6, 0.9],
           [0.8, 0.4, 0.8, 0.3, 0.7],
           [0.3, 0.2, 0.4, 0.3, 0.8]])
    res = tf.vectorized_map(masking, inputs)
    print (res)
    
    tf.Tensor(
    [[0.8 0.7 0.  0.6 0.9]
     [0.8 0.  0.8 0.3 0.7]
     [0.3 0.  0.  0.  0. ]], shape=(3, 5), dtype=float32)
    

    I tested it with

    %timeit tf.map_fn(masking, inputs)
    
    %timeit tf.vectorized_map(masking, inputs)
    

    and the tf.vectorized_map(masking, inputs) get faster when the batch size increase