Tags: tensorflow, keras, autoencoder, loss

Getting NaN losses when fitting a VAE in Keras


I am trying to build a Variational Autoencoder (VAE) on CIFAR-10 images with Keras. It works perfectly on MNIST, but with CIFAR-10 my losses (reconstruction loss and KL loss) are NaN when I call fit (screenshot: NaN loss).

Here is my VAE with a custom training step:

Note: CIFAR-10 image shape = (32, 32, 3), latent dimension = 2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model


class VAE(Model):

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)

        # encoder and decoder
        self.encoder = encoder
        self.decoder = decoder

        # loss metrics
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            # encoder outputs: latent mean, z_sigma (used as log-variance below), sampled z
            z_mu, z_sigma, z = self.encoder(data)
            z_decoded = self.decoder(z)

            # compute the losses
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, z_decoded), axis=(1, 2)
                )
            )
            kl_loss = -(1 + z_sigma - z_mu**2 - tf.exp(z_sigma)) / 2
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        # gradients
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # update the loss trackers
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        # return the final losses
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
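To make the two loss terms easier to inspect outside TensorFlow, here is a minimal NumPy sketch that mirrors them. It assumes, as the KL formula in train_step implies, that the encoder's `z_sigma` output is the log-variance log(sigma^2) rather than sigma itself. The helper names `bce_per_pixel`, `reconstruction_loss` and `kl_loss` are hypothetical. Like `keras.losses.binary_crossentropy`, the per-pixel cross-entropy here averages over the last (channel) axis, so summing over `axis=(1, 2)` leaves one value per image.

```python
import numpy as np

EPS = 1e-7  # clip predictions away from 0 and 1 (assumed, similar to Keras's epsilon)

def bce_per_pixel(y_true, y_pred):
    """Binary cross-entropy averaged over the channel (last) axis."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return bce.mean(axis=-1)  # (batch, H, W, C) -> (batch, H, W)

def reconstruction_loss(x, x_decoded):
    # Sum over the spatial axes, then average over the batch.
    return bce_per_pixel(x, x_decoded).sum(axis=(1, 2)).mean()

def kl_loss(z_mu, z_log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with z_log_var = log(sigma^2);
    # algebraically equal to -(1 + z_log_var - mu^2 - exp(z_log_var)) / 2.
    kl = 0.5 * (np.exp(z_log_var) + z_mu**2 - 1.0 - z_log_var)
    return kl.sum(axis=1).mean()

rng = np.random.default_rng(0)
x = rng.random((4, 32, 32, 3))          # CIFAR-10-shaped batch already scaled to [0, 1]
x_decoded = rng.random((4, 32, 32, 3))  # stand-in for a sigmoid decoder output
z_mu = np.zeros((4, 2))                 # latent dimension = 2
z_log_var = np.zeros((4, 2))

print(bce_per_pixel(x, x_decoded).shape)  # (4, 32, 32)
print(reconstruction_loss(x, x_decoded))  # a positive number for targets in [0, 1]
print(kl_loss(z_mu, z_log_var))           # 0.0 when mu = 0 and sigma = 1
```

With targets in [0, 1] every term of the reconstruction loss is non-negative, and the KL term is zero only when the posterior matches the standard normal prior.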

My encoder: [encoder graph image]

My decoder: [decoder graph image]

Does anyone have an idea?


Solution

  • In case this helps someone: I faced the exact same problem, and what fixed it for me was keeping binary_crossentropy but making sure the data was normalized, i.e. all image pixel values were between 0 and 1. Something like this might help:

    datagen = ImageDataGenerator(rescale=1./255, <anything else you want>)
    
    

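    Keeping the pixel values between 0 and 1 is important because binary cross-entropy expects targets in that range. With raw 0–255 values the loss can become negative and keep decreasing as the decoder's outputs saturate, so training can push the outputs and weights toward extreme values until they overflow to NaN.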
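To illustrate why the range matters, the sketch below evaluates the clipped binary cross-entropy formula for one pixel with a raw target of 200 and with the same pixel rescaled to 200/255. It then rescales a random uint8 batch the way you would rescale the arrays returned by `keras.datasets.cifar10.load_data()` (which are uint8) when you call fit on arrays instead of using ImageDataGenerator. The helper name `bce`, the pixel value 200 and the `EPS` clip are assumptions chosen for illustration.

```python
import numpy as np

EPS = 1e-7  # clip predictions away from 0 and 1 (assumed, similar to Keras)

def bce(y_true, y_pred):
    """Element-wise binary cross-entropy, no averaging."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    return -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))

predictions = np.array([0.5, 0.9, 0.999, 0.999999])

# Raw pixel value as the target: the loss drops below zero and keeps falling
# as the prediction saturates toward 1, bounded only by the EPS clip.
raw_target = 200.0
print(bce(raw_target, predictions))

# Same pixel rescaled to [0, 1]: every loss value is non-negative, and the
# loss is smallest when the prediction equals the target.
scaled_target = raw_target / 255.0
print(bce(scaled_target, predictions))

# Rescaling a whole uint8 batch (random stand-in for CIFAR-10 images).
images = np.random.default_rng(0).integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
images = images.astype("float32") / 255.0
print(images.min() >= 0.0, images.max() <= 1.0)  # True True
```

With the raw target, a more saturated output always looks "better" to the optimizer, which is one way the feedback loop described above can arise.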