Tags: python, tensorflow, machine-learning, deep-learning, gradienttape

Checking condition in call method of custom layer using tf.cond()


I'm implementing a custom layer in TensorFlow 2.x. My requirement is that the layer should check a condition before returning its output.

import tensorflow as tf

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
        super(SimpleRNN_cell, self).__init__()
        pass  # weight creation (including self.h) omitted here

    def call(self, X, hidden_state, return_state=True):
        y = tf.constant(5)
        if return_state:
            return y, self.h
        else:
            return y

My question is: should I continue using the present code (assuming that tape.gradient(Loss, self.trainable_weights) will work fine), or should I use tf.cond()? Also, if possible, please explain where to use tf.cond() and where not to. I haven't found much content on this topic.


Solution

  • tf.cond is only relevant when performing conditional evaluation that depends on data in the differentiable computation graph (https://www.tensorflow.org/api_docs/python/tf/cond). This was especially necessary in TF 1.x, where graph mode was the default. In eager mode, the GradientTape system also supports conditional data flow with plain Python constructs such as if ...: (https://www.tensorflow.org/guide/autodiff#control_flow). A sketch of a data-dependent branch follows below.

    Yet for merely providing different behavior based on configuration parameters that don't depend on data from the computational graph and are fixed during the model's runtime, simple Python if statements are correct; the second sketch below demonstrates this with the question's return_state flag.
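
As a minimal sketch of the first case (the function name and input values are made up): when the branch condition is computed from tensor data and the code runs as a graph under tf.function, the branch has to be a graph op. Note that AutoGraph usually converts a plain Python if on a tensor into tf.cond automatically, so writing tf.cond by hand is mostly about being explicit:

import tensorflow as tf

@tf.function
def scaled(x):
    # The predicate is a tensor derived from x, so in graph mode the
    # branch must be a graph op: tf.cond traces both branches and
    # selects one at run time.
    return tf.cond(tf.reduce_mean(x) > 0,
                   lambda: x * 2.0,   # taken when the mean is positive
                   lambda: x / 2.0)   # taken otherwise

x = tf.constant([1.0, -1.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = scaled(x)
print(tape.gradient(y, x))  # gradients flow through the executed branch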
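
For the second case, here is a runnable sketch of the question's pattern (the weights, shapes, and class name are illustrative, not the asker's actual ones). return_state is a fixed Python bool, not tensor data, so a plain if is fine and tape.gradient over the trainable weights works as expected:

import tensorflow as tf

class MinimalCell(tf.keras.layers.Layer):
    def __init__(self, M1, M2):
        super().__init__()
        self.Wx = self.add_weight(shape=(M1, M2), initializer="glorot_uniform")
        self.Wh = self.add_weight(shape=(M2, M2), initializer="glorot_uniform")

    def call(self, X, hidden_state, return_state=True):
        h = tf.nn.tanh(X @ self.Wx + hidden_state @ self.Wh)
        if return_state:  # plain Python control flow on a config flag
            return h, h
        return h

cell = MinimalCell(4, 3)
X = tf.random.normal((2, 4))
with tf.GradientTape() as tape:
    y, h = cell(X, tf.zeros((2, 3)))
    loss = tf.reduce_sum(y)
print(tape.gradient(loss, cell.trainable_weights))  # valid gradients, no tf.cond needed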