Tags: python, tensorflow, keras, batch-normalization, generative-adversarial-network

Why is tf.keras BatchNormalization causing GANs to produce nonsense loss and accuracy?


Background:

I've been getting unusual losses and accuracies when training GANs with batch normalization layers in the discriminator using tf.keras. GANs have an optimal objective function value of log(4), which occurs when the discriminator is completely unable to discern real samples from fakes and hence predicts 0.5 for all samples. When I include BatchNormalization layers in my discriminator, both the generator and the discriminator achieve near perfect scores (high accuracy, low loss), which is impossible in an adversarial setting.
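To make the log(4) figure concrete, here is a small sanity check (my own sketch, not part of the question's code): the combined binary cross-entropy of a discriminator that outputs 0.5 for every real and fake sample works out to 2·log(2) = log(4) ≈ 1.386.

```python
import math
import tensorflow as tf

# Quick check of the log(4) figure quoted above: a discriminator that predicts
# 0.5 for every sample incurs -log(0.5) cross-entropy on real samples (label 1)
# and -log(0.5) on fake samples (label 0), so its combined loss is
# 2 * log(2) = log(4) ~ 1.386.
bce = tf.keras.losses.BinaryCrossentropy()
loss_on_real = bce([[1.0]], [[0.5]])       # ~0.693
loss_on_fake = bce([[0.0]], [[0.5]])       # ~0.693
print(float(loss_on_real + loss_on_fake))  # ~1.386
print(math.log(4))                         # 1.3862...
```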

Without BatchNorm:

The first figure shows the losses (y-axis) per epoch (x-axis) when BN is not used; occasional values below the theoretical minimum are expected because training is an iterative process and the discriminator is never exactly optimal at any given step. The second figure shows the accuracies when BN is not used, which settle at about 50% each. Both figures show reasonable values.

With BatchNorm:

The first figure shows the losses (y-axis) per epoch (x-axis) when BN is used; the GAN objective, which should not fall below log(4), approaches 0. The second figure shows the accuracies when BN is used, with both approaching 100%. GANs are adversarial, so the generator and the discriminator cannot both reach 100% accuracy.

Question:

The code for building and training the GAN can be found here. Am I missing something: have I made a mistake in my implementation, or is there a bug in tf.keras? I'm fairly sure this is a technical issue rather than a theoretical problem that "GAN-hacks" can solve. Note that the issue only arises when BatchNormalization layers are used in the discriminator; using them in the generator does not cause it.
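For reference, a minimal sketch of the configuration described above, with BatchNormalization layers inside the discriminator, might look like the following. This is illustrative only; the layer sizes, input shape, and the `build_discriminator` name are placeholders, not the linked code.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_discriminator(input_shape=(28, 28, 1)):
    """Toy discriminator with BatchNormalization layers (placeholder sizes)."""
    return tf.keras.Sequential([
        layers.Flatten(input_shape=input_shape),
        layers.Dense(256),
        layers.BatchNormalization(),   # removing these layers restores sane losses
        layers.LeakyReLU(0.2),
        layers.Dense(128),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Dense(1, activation="sigmoid"),  # real-vs-fake probability
    ])
```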


Solution

  • There is an issue with TensorFlow's BatchNormalization layer in TF 2.0 and 2.1; downgrading to TF 1.15 resolves the problem (see the version check and downgrade sketch below). The cause of the problem has not yet been determined.

    Here is the relevant GitHub issue: https://github.com/tensorflow/tensorflow/issues/37673
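As a practical note, a simple way to apply the workaround is to confirm which release is installed and, if it is 2.0 or 2.1, pin the earlier version. The exact pip command below is my suggestion, not part of the answer itself.

```python
import tensorflow as tf

# The misbehaviour was reported on TF 2.0 and 2.1 and not on TF 1.15, so first
# confirm which release is installed. To apply the downgrade workaround,
# reinstall the older release, e.g.:
#   pip install tensorflow==1.15
print(tf.__version__)
```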