tensorflow, generative-adversarial-network, tpu

Using tf.cond() in an estimator model function to train a WGAN on a TPU causes doubled global_step


I'm trying to train a GAN on a TPU, so I've been experimenting with the TPUEstimator class and its model function to implement the WGAN training loop. I'm trying to use tf.cond to merge the two training ops for the TPUEstimatorSpec, like so:

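# tf.mod(step, N + 1) yields 0..N, so compare against N (not N + 1)
# to select the generator once per cycle.
opt = tf.cond(
    tf.equal(
        tf.mod(tf.train.get_or_create_global_step(),
               CRITIC_UPDATES_PER_GEN_UPDATE + 1),
        CRITIC_UPDATES_PER_GEN_UPDATE),
    lambda: gen_opt,
    lambda: critic_opt
)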

gen_opt and critic_opt are the ops returned by the minimize function of the optimizer I'm using, each set to update the global step as well. CRITIC_UPDATES_PER_GEN_UPDATE is a Python constant giving the number of critic updates per generator update, as WGAN training requires. I've looked for a GAN model that uses tf.cond, but every model I found uses tf.group, which I can't use because the critic has to be optimized many more times than the generator. However, every time I run 100 batches, the global step increases by 200 according to the checkpoint number. Is my model still training correctly, or is tf.cond just not supposed to be used this way to train GANs?
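For reference, the schedule the condition is meant to express (several critic updates, then one generator update) can be sketched in plain Python. This is only an illustration: `branch_for_step` is a hypothetical helper, the value 5 is made up, and the sketch assumes the global step advances by exactly one per training step. Because a remainder modulo N+1 ranges over 0..N, the generator's turn is keyed on remainder N.

```python
CRITIC_UPDATES_PER_GEN_UPDATE = 5  # illustrative value, not from the question


def branch_for_step(global_step, n_critic=CRITIC_UPDATES_PER_GEN_UPDATE):
    """Return which network a step trains under an n_critic : 1 schedule.

    Hypothetical helper for illustration; assumes global_step grows by 1
    per training step.
    """
    # global_step % (n_critic + 1) takes values 0..n_critic, so the last
    # slot of each cycle (remainder == n_critic) is the generator's turn.
    if global_step % (n_critic + 1) == n_critic:
        return "generator"
    return "critic"


# Two full cycles: five critic steps followed by one generator step, twice.
schedule = [branch_for_step(step) for step in range(12)]
print(schedule)
```

If the global step advances by two per training step instead, the remainders visited change and the schedule no longer follows this pattern.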


Solution

  • tf.cond is not supposed to be used in this way to train GANs.

    You get 200 because, on every training step, the side effects (such as assignment operations) of both true_fn and false_fn are executed. gen_opt and critic_opt are created outside the lambdas, so the lambdas only return ops that already exist in the graph, and tf.cond runs both of them as dependencies regardless of the condition. One of those side effects is the tf.assign_add on the global step that both optimizers define.

    Hence, what happens on each step is:

    • Execution of global_step++ (gen_opt) and global_step++ (critic_opt)
    • Evaluation of the condition
    • Execution of the true_fn body or the false_fn body (depending on the condition).

    If you want to train a GAN using tf.cond, you have to move every op with side effects (the assignments, and therefore the definition of the optimization step itself) out of the enclosing scope and create it inside true_fn/false_fn.

    As a reference, you can see this answer about the behaviour of tf.cond: https://stackoverflow.com/a/37064128/2891324
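The behaviour described above can be reproduced with a bare counter instead of real optimizers. This is a minimal sketch, assuming TensorFlow's TF1 graph mode through `tensorflow.compat.v1`; the `counter` variable and the `count_after_one_step` helper are hypothetical stand-ins for the global step and the training ops, not part of the original model.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode, the mode TPUEstimator runs in


def count_after_one_step(create_inside_branches):
    """Run one tf.cond-guarded step and return the counter's value.

    Hypothetical helper: `counter` stands in for the global step.
    """
    graph = tf.Graph()
    with graph.as_default():
        counter = tf.get_variable(
            "counter", shape=[], dtype=tf.int64,
            initializer=tf.zeros_initializer())
        pred = tf.constant(True)

        def increment_inside():
            # The assignment is created inside the branch function, so it
            # only runs when this branch is the one selected.
            with tf.control_dependencies([tf.assign_add(counter, 1)]):
                return tf.identity(counter)

        if create_inside_branches:
            step = tf.cond(pred, increment_inside, increment_inside)
        else:
            # Both assignments already exist outside tf.cond; the lambdas
            # merely reference them, so both become inputs of the cond.
            inc_a = tf.assign_add(counter, 1)
            inc_b = tf.assign_add(counter, 1)
            step = tf.cond(pred,
                           lambda: tf.identity(inc_a),
                           lambda: tf.identity(inc_b))

        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(step)
            return int(sess.run(counter))


outside = count_after_one_step(create_inside_branches=False)
inside = count_after_one_step(create_inside_branches=True)
print(outside, inside)
```

With the assignments created outside the branches, a single step should advance the counter twice, mirroring the doubled global step; with them created inside, it should advance once. Applying the same idea to the real model means building each optimizer's minimize call inside its branch function, which may need extra care for optimizers that create slot variables.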