
What does self._compute_output_and_mask_jointly = True do in tf.keras.layers.Masking layer?

tf.keras.layers.Masking sets _compute_output_and_mask_jointly to True in its __init__(...). What does this attribute actually do, beyond describing what call(...) does?

def __init__(self, mask_value=0., **kwargs):
  self._compute_output_and_mask_jointly = True

In addition, the mask is already created and applied in call(...). What is the purpose of compute_mask(...)? It seems redundant.

  def compute_mask(self, inputs, mask=None):
    return tf.reduce_any(tf.not_equal(inputs, self.mask_value), axis=-1)

  def call(self, inputs):
    boolean_mask = tf.reduce_any(
        tf.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
    outputs = inputs * tf.cast(boolean_mask, inputs.dtype)
    # Compute the mask and outputs simultaneously.
    outputs._keras_mask = tf.squeeze(boolean_mask, axis=-1)  # pylint: disable=protected-access
    return outputs
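Both methods run the same reduction, which is the redundancy the question points out. Below is a plain-Python sketch of that reduction. It assumes nested lists stand in for a (batch, timesteps, features) tensor. The names masking_call and masking_compute_mask are hypothetical and are not TensorFlow functions:

```python
# Plain-Python sketch of what Masking.call() and Masking.compute_mask()
# compute. Nested lists stand in for a (batch, timesteps, features) tensor;
# the function names are hypothetical, not TensorFlow APIs.


def masking_call(inputs, mask_value=0.0):
    """Mirror of call(): returns (outputs, mask) computed in one pass."""
    outputs, mask = [], []
    for sample in inputs:
        out_steps, step_mask = [], []
        for step in sample:
            # reduce_any(not_equal(step, mask_value)) over the feature axis
            keep = any(f != mask_value for f in step)
            # inputs * cast(boolean_mask): the one boolean per timestep
            # is applied to every feature of that timestep
            out_steps.append([f * float(keep) for f in step])
            step_mask.append(keep)
        outputs.append(out_steps)
        mask.append(step_mask)  # (batch, timesteps), like the squeezed mask
    return outputs, mask


def masking_compute_mask(inputs, mask_value=0.0):
    """Mirror of compute_mask(): the same reduction, mask only."""
    return [[any(f != mask_value for f in step) for step in sample]
            for sample in inputs]


batch = [[[1.0, 2.0], [0.0, 0.0], [3.0, 0.0]]]
outputs, mask = masking_call(batch)
print(outputs)  # [[[1.0, 2.0], [0.0, 0.0], [3.0, 0.0]]]
print(mask)     # [[True, False, True]]
print(masking_compute_mask(batch) == mask)  # True
```

In the TensorFlow version, keepdims=True exists only so the per-timestep boolean can broadcast over the feature axis during the multiplication; tf.squeeze then drops that axis so the stored mask has the same (batch, timesteps) shape that compute_mask(...) returns.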


  • First of all, a hefty, fair warning:

    This is an implementation detail; never use it!

    It may, in fact, be on the way out.

    Having said that, this is a minor optimization, used by exactly one of all the layer classes there are: layers.Masking. It is part of TensorFlow Keras (as opposed to TensorFlow proper). When this attribute is present and set to True on a layer, the Keras framework assumes that call(...), invoked from __call__, has already computed the output mask and attached it to the output tensor's _keras_mask attribute, so it skips the call to compute_mask, in both eager and graph-tracing modes. The compute_mask method still exists because it is part of the public Layer API: anything that calls it directly still gets the correct mask. That's all there is to it. No magic turned up to eleven.

    In fact, attaching a _keras_mask attribute to the output tensor has the same effect. And you will indeed avoid a nasty surprise one day by setting neither of these attributes.
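The shortcut can be modeled with a toy, pure-Python layer. This is a sketch under stated assumptions, not Keras code. ToyTensor, ToyLayer and ToyJointMasking are hypothetical, and the check in ToyLayer.__call__ only roughly mirrors the condition in the TF 2.x Layer machinery, which also handles nested outputs, supports_masking and graph tracing. The flag is read with getattr(..., False), so layers that never set it take the normal path:

```python
# Toy model of the "skip compute_mask" shortcut. All class names are
# hypothetical; this is a simplification, not the real Keras implementation.


class ToyTensor:
    """Stand-in for a tensor that can carry a mask attribute."""

    def __init__(self, steps):
        self.steps = steps
        self._keras_mask = None  # same attribute name Masking.call() sets


def step_mask(steps, mask_value=0.0):
    # True for a timestep if any feature differs from mask_value.
    return [any(f != mask_value for f in step) for step in steps]


class ToyLayer:
    """A layer that computes its mask only in compute_mask()."""

    def __init__(self, mask_value=0.0):
        self.mask_value = mask_value
        self.compute_mask_calls = 0

    def call(self, inputs):
        return ToyTensor(inputs.steps)

    def compute_mask(self, inputs, mask=None):
        self.compute_mask_calls += 1
        return step_mask(inputs.steps, self.mask_value)

    def __call__(self, inputs, mask=None):
        outputs = self.call(inputs)
        # Skip compute_mask() when the layer promises a joint computation,
        # or when call() has already attached a mask to its output.
        already_computed = (
            getattr(self, "_compute_output_and_mask_jointly", False)
            or outputs._keras_mask is not None
        )
        if not already_computed:
            outputs._keras_mask = self.compute_mask(inputs, mask)
        return outputs


class ToyJointMasking(ToyLayer):
    """Like Masking: sets the flag and attaches the mask inside call()."""

    def __init__(self, mask_value=0.0):
        super().__init__(mask_value)
        self._compute_output_and_mask_jointly = True

    def call(self, inputs):
        keep = step_mask(inputs.steps, self.mask_value)  # computed once
        outputs = ToyTensor(
            [[f * float(k) for f in step] for step, k in zip(inputs.steps, keep)]
        )
        outputs._keras_mask = keep
        return outputs


x = ToyTensor([[1.0, 2.0], [0.0, 0.0], [3.0, 0.0]])
plain, joint = ToyLayer(), ToyJointMasking()
out_plain, out_joint = plain(x), joint(x)
print(plain.compute_mask_calls, out_plain._keras_mask)  # 1 [True, False, True]
print(joint.compute_mask_calls, out_joint._keras_mask)  # 0 [True, False, True]
```

ToyLayer has to call compute_mask once, while ToyJointMasking produces the same mask with zero calls. Deleting the flag from ToyJointMasking.__init__ would not change its count, because its call() already attaches a mask to the output, and the toy check accepts that as well.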