I have a question about tf.cond. I am running two images through a CNN and want to backpropagate only the lower of the two cross-entropy losses. The code is as follows:
train_cross_entropy = tf.cond(train_cross_entropy1 < train_cross_entropy2,
                              lambda: train_cross_entropy1,
                              lambda: train_cross_entropy2)
Training with this train_cross_entropy is just as slow as writing
train_cross_entropy = train_cross_entropy1 + train_cross_entropy2
which suggests to me that it is backpropagating through both parts of the graph instead of just one.
I would expect it to be almost as fast as writing
train_cross_entropy = train_cross_entropy1
apart from the second forward pass, which has to run anyway so that the two losses can be compared.
It would be greatly appreciated if anybody had any ideas on how to accomplish this! Thanks.
I solved it by moving the gradient computation inside the tf.cond branches, like so:
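def f1():
    # Taken when loss 1 is lower: gradients of loss 1 w.r.t. var_list only.
    grads = tf.gradients(train_cross_entropy1, var_list,
                         stop_gradients=[train_cross_entropy2])
    return grads

def f2():
    # Taken when loss 2 is lower: gradients of loss 2 w.r.t. var_list only.
    grads = tf.gradients(train_cross_entropy2, var_list,
                         stop_gradients=[train_cross_entropy1])
    return grads

# Ops created inside f1/f2 run only for the selected branch, so only one
# set of gradients is computed per step.
gradients = tf.cond(train_cross_entropy1 < train_cross_entropy2, f1, f2)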
I then apply these gradients later, e.g. with optimizer.apply_gradients(zip(gradients, var_list)).
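To illustrate the same idea outside TensorFlow, here is a minimal, framework-free sketch. It assumes a hypothetical linear model (parameters w and b) with a squared-error loss standing in for the CNN and cross entropy, and the names forward, backward and min_loss_step are hypothetical. Both forward passes run, because the comparison needs both losses, but only the backward pass for the lower-loss input is computed, mirroring what the tf.cond branches above do.

```python
def forward(params, x, target):
    """Forward pass: returns the prediction and squared-error loss for one input."""
    w, b = params
    pred = w * x + b
    return pred, (pred - target) ** 2


def backward(x, target, pred):
    """Backward pass for one input: gradients of its loss w.r.t. [w, b]."""
    d_pred = 2.0 * (pred - target)
    return [d_pred * x, d_pred]


def min_loss_step(params, sample1, sample2, lr=0.1):
    """One training step that backpropagates only the lower of two losses."""
    # Both forward passes are needed; otherwise the losses cannot be compared.
    pred1, loss1 = forward(params, *sample1)
    pred2, loss2 = forward(params, *sample2)
    # Plays the role of tf.cond: only the selected branch's backward pass runs.
    if loss1 < loss2:
        grads, chosen = backward(*sample1, pred1), 1
    else:
        grads, chosen = backward(*sample2, pred2), 2
    # "Apply gradients later": a plain SGD update on the chosen gradients.
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, chosen, min(loss1, loss2)


params = [1.0, 0.0]   # hypothetical initial [w, b]
sample1 = (2.0, 3.0)  # (input, target): loss 1.0 under these params
sample2 = (1.0, 5.0)  # (input, target): loss 16.0 under these params
new_params, chosen, loss = min_loss_step(params, sample1, sample2)
print(chosen, loss, new_params)
```

When the two losses are equal, the sketch falls through to the second branch, just as the predicate train_cross_entropy1 < train_cross_entropy2 does in the tf.cond version.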