Tags: tensorflow, conv-neural-network, backpropagation

Proper use of tf.cond in CNN


I have a question regarding the use of tf.cond. I am trying to run two images through a CNN and backpropagate only through the one that produces the lower cross-entropy loss. The code is as follows:

    train_cross_entropy = tf.cond(train_cross_entropy1 < train_cross_entropy2,
                                  lambda: train_cross_entropy1,
                                  lambda: train_cross_entropy2)
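
For context, train_cross_entropy1 and train_cross_entropy2 come from something like the sketch below (build_cnn is just a placeholder for my shared network, and the exact loss call is illustrative):

    # Shared CNN applied to both images (build_cnn is a placeholder)
    logits1 = build_cnn(image1)
    logits2 = build_cnn(image2)
    train_cross_entropy1 = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits1))
    train_cross_entropy2 = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits2))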

Using this train_cross_entropy is just as slow as writing train_cross_entropy = train_cross_entropy1 + train_cross_entropy2.

This suggests to me that it is backpropagating through both parts of the graph instead of just one. I would hope it would be almost as fast as writing train_cross_entropy = train_cross_entropy1.

Any ideas on how to accomplish this would be greatly appreciated. Thanks.


Solution

  • I just had to move the gradient calculation inside the tf.cond like so:

        def f1():
            # Gradient of the first loss w.r.t. the variables;
            # stop_gradients treats the other loss tensor as a constant.
            grads = tf.gradients(train_cross_entropy1, var_list,
                                 stop_gradients=[train_cross_entropy2])
            return grads

        def f2():
            # Gradient of the second loss w.r.t. the variables.
            grads = tf.gradients(train_cross_entropy2, var_list,
                                 stop_gradients=[train_cross_entropy1])
            return grads

        # Pick whichever branch has the lower loss; only the selected
        # branch is executed at run time.
        gradients = tf.cond(train_cross_entropy1 < train_cross_entropy2, f1, f2)
    

    And then I can apply gradients later on.
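
    For example, with a placeholder optimizer and learning rate (any tf.train optimizer would work here; these specifics are not part of the original snippet):

        # Assumed optimizer and learning rate, just to show the wiring
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        # Pair each selected gradient with its variable and build the update op
        train_step = optimizer.apply_gradients(list(zip(gradients, var_list)))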