I'm trying to train a GAN on a TPU, so I've been messing around with the TPUEstimator class and accompanying model function to try to implement the WGAN training loop. I'm trying to use tf.cond to merge the two training ops for the TPUEstimatorSpec, like so:
opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(),
                    CRITIC_UPDATES_PER_GEN_UPDATE + 1),
             CRITIC_UPDATES_PER_GEN_UPDATE + 1),
    lambda: gen_opt,
    lambda: critic_opt
)
gen_opt and critic_opt are the minimize ops of the optimizer I'm using, both set to update the global step as well. CRITIC_UPDATES_PER_GEN_UPDATE is a Python constant for exactly that, and is part of the WGAN training scheme. I've tried to find a GAN model that uses tf.cond, but all the models I've found use tf.group, which I can't use because the critic needs to be optimized many more times than the generator.
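For context, the two ops are created outside of the tf.cond, roughly like this (the optimizer, learning rate, losses and variable lists below are just placeholders for illustration):

global_step = tf.train.get_or_create_global_step()
gen_opt = tf.train.GradientDescentOptimizer(1e-4).minimize(
    gen_loss, var_list=gen_vars, global_step=global_step)
critic_opt = tf.train.GradientDescentOptimizer(1e-4).minimize(
    critic_loss, var_list=critic_vars, global_step=global_step)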
However, every time I run 100 batches, the global step increases by 200 according to the checkpoint number. Is my model still training correctly, or is tf.cond just not supposed to be used this way to train GANs?
tf.cond is not supposed to be used in this way to train GANs.
You get 200 because on every training step the side effects (like assignment operations) of both the true_fn and the false_fn are evaluated. One of those side effects is the global step tf.assign_add operation that both optimizers define.

Hence, what happens at every step is roughly:

global_step++ (gen_opt)
global_step++ (critic_opt)
the true_fn body or the false_fn body, depending on the condition

If you want to train a GAN using tf.cond, you have to move every side-effecting operation (like the assignment, and hence the definition of the optimization step itself) out of the surrounding code and declare everything inside true_fn/false_fn.
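For completeness, here is a minimal sketch of that idea, assuming TF 1.x and that gen_loss, critic_loss, gen_vars and critic_vars are already defined (the optimizer and learning rate are placeholders; note that the condition compares against CRITIC_UPDATES_PER_GEN_UPDATE rather than CRITIC_UPDATES_PER_GEN_UPDATE+1, since a value mod N+1 can never equal N+1):

global_step = tf.train.get_or_create_global_step()

def train_gen():
    # minimize (and its assign_add on global_step) is created here,
    # inside true_fn, so it is only executed when this branch is taken
    return tf.train.GradientDescentOptimizer(1e-4).minimize(
        gen_loss, var_list=gen_vars, global_step=global_step)

def train_critic():
    # same thing, inside false_fn
    return tf.train.GradientDescentOptimizer(1e-4).minimize(
        critic_loss, var_list=critic_vars, global_step=global_step)

opt = tf.cond(
    tf.equal(tf.mod(global_step, CRITIC_UPDATES_PER_GEN_UPDATE + 1),
             CRITIC_UPDATES_PER_GEN_UPDATE),
    train_gen,
    train_critic)

This way the global step is incremented exactly once per step, by whichever branch actually runs.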
As a reference, you can see this answer about the behaviour of tf.cond: https://stackoverflow.com/a/37064128/2891324