Tags: python, keras, tensorflow2.0, autoencoder, generative-adversarial-network

VAE with a discriminator compiling problem


Unlike vanilla generative models, the input to this VAE is an RGB image. If I compile self.combined using the add_loss method, the loss swings from around 15000 down to -22000. Compiling with mse instead works fine.

    # Assumed context: tf = tensorflow, np = numpy, K = tf.keras.backend;
    # VAE, patch_discriminator and DataLoader are project-specific classes.
    def __init__(self, type='landmark'):

        self.latent_dim = 128
        self.input_shape = (128, 128, 3)
        self.batch_size = 1
        self.original_dim = self.latent_dim * self.latent_dim
        patch = int(self.input_shape[0] / 2**4)   # PatchGAN output size (8x8 for 128x128 input)
        self.disc_patch = (patch, patch, 1)

        optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

        # Discriminator is compiled for its own updates, then frozen inside
        # the combined model.
        pd = patch_discriminator(type)
        self.discriminator = pd.discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
        self.discriminator.trainable = False

        vae = VAE(self.latent_dim, type=type)
        encoder = vae.inference_net()
        decoder = vae.generative_net()

        # Placeholder reconstruction target; replaced with the real batch
        # inside the training loop.
        if type == 'image':
            self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 3))
        else:
            self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 1))

        vae_input = tf.keras.layers.Input(shape=self.input_shape)
        self.encoder_out = encoder(vae_input)          # [z_mean, z_log_var, z]
        self.decoder_out = decoder(self.encoder_out[2])

        # Generator (the VAE itself), trained with the custom VAE loss.
        self.generator = tf.keras.Model(vae_input, self.decoder_out)
        vae_loss = self.compute_loss()
        self.generator.add_loss(vae_loss)
        self.generator.compile(optimizer=optimizer)

        # Combined model: VAE output fed through the frozen discriminator.
        valid = self.discriminator([self.decoder_out, self.decoder_out])
        self.combined = tf.keras.Model(vae_input, valid)
        self.combined.add_loss(vae_loss)
        self.combined.compile(optimizer=optimizer)
        # self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        self.dl = DataLoader()

compute_loss computes the reconstruction plus KL loss for the VAE. Initially, self.orig_out is set to a random normal tensor; it is then replaced with the real target batch in the training loop below.

    def compute_loss(self):
        # Reconstruction term: pixel-wise binary cross-entropy scaled by the
        # number of pixels. Note that tf.keras losses expect (y_true, y_pred).
        bce = tf.keras.losses.BinaryCrossentropy()
        reconstruction_loss = bce(self.orig_out, self.decoder_out)
        reconstruction_loss = self.original_dim * reconstruction_loss

        # KL divergence between q(z|x) and the unit Gaussian prior.
        z_mean = self.encoder_out[0]
        z_log_var = self.encoder_out[1]
        kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5

        vae_loss = K.mean(reconstruction_loss + kl_loss)
        return vae_loss
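
For reference, this loss wiring follows the add_loss pattern of the official Keras VAE example, where the loss tensor is built entirely from symbolic tensors of a single model. The sketch below is only an illustration of that pattern, with small dense layers rather than the convolutional networks used here:

    # Minimal sketch of the standard Keras VAE add_loss pattern (layer sizes
    # and names are illustrative, not the networks from this model).
    import tensorflow as tf
    from tensorflow.keras import layers, Model
    from tensorflow.keras import backend as K

    original_dim, latent_dim = 784, 2

    inputs = layers.Input(shape=(original_dim,))
    h = layers.Dense(256, activation='relu')(inputs)
    z_mean = layers.Dense(latent_dim)(h)
    z_log_var = layers.Dense(latent_dim)(h)

    def sampling(args):
        z_mean, z_log_var = args
        eps = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim))
        return z_mean + K.exp(0.5 * z_log_var) * eps

    z = layers.Lambda(sampling)([z_mean, z_log_var])
    outputs = layers.Dense(original_dim, activation='sigmoid')(
        layers.Dense(256, activation='relu')(z))
    vae = Model(inputs, outputs)

    # Every tensor in the loss is symbolic and belongs to this model, so no
    # target has to be passed to fit()/train_on_batch().
    reconstruction_loss = original_dim * tf.keras.losses.binary_crossentropy(inputs, outputs)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    vae.add_loss(K.mean(reconstruction_loss + kl_loss))
    vae.compile(optimizer='adam')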

Training loop:

    def train(self, batch_size=1, epochs=10):
        start_time = datetime.datetime.now()
        # PatchGAN targets: all-ones patches for real, all-zeros for fake.
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        threshold = epochs // 10

        for epoch in range(epochs):
            for batch_i, (imA, imB, n_batches) in enumerate(self.dl.load_batch(target='landmark', batch_size=batch_size)):
                # Update the reconstruction target used by compute_loss.
                self.orig_out = tf.convert_to_tensor(imB, dtype=tf.float32)
                fakeA = self.generator.predict(imA)

                # Train the discriminator on real and generated samples.
                d_real_loss = self.discriminator.train_on_batch([imB, imB], valid)
                d_fake_loss = self.discriminator.train_on_batch([imB, fakeA], fake)
                d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)

                # With add_loss(), no target can be passed here.
                combined_loss = self.combined.train_on_batch(imA)
                # combined_loss = self.combined.train_on_batch(imA, valid)

                elapsed_time = datetime.datetime.now() - start_time

                print(f"[Epoch {epoch}/{epochs}] [Batch {batch_i}/{n_batches}] [D loss: {d_loss}] [G loss: {combined_loss}] time: {elapsed_time}")

If I compile self.combined with the KL loss using the add_loss() method, I cannot pass targets to train_on_batch, as shown above. As a result, the generator does not learn and produces random outputs. How do I compile the VAE together with the discriminator while keeping the KL loss?


Solution

  • I don't know if this will be the right answer, but a VAE like this is often easier to build with a custom TensorFlow training loop, where you compute and apply the losses yourself instead of going through compile()/add_loss(). You can follow this link, which may contain some relevant information for your problem; a rough sketch of that approach is shown below.
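
As an illustration only (a sketch of that suggestion, not the question's actual code: encoder, decoder, discriminator and original_dim stand in for the objects built in __init__, and the encoder is assumed to return (z_mean, z_log_var, z) as above), the generator update could be written as a custom step with tf.GradientTape:

    # Hedged sketch of a custom generator/VAE training step. `encoder`,
    # `decoder`, `discriminator` and `original_dim` are assumed to correspond
    # to the objects created in __init__ above.
    import tensorflow as tf

    g_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)
    bce = tf.keras.losses.BinaryCrossentropy()

    @tf.function
    def train_generator_step(imA, imB, valid):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = encoder(imA, training=True)
            fakeA = decoder(z, training=True)

            # VAE part: scaled reconstruction loss plus KL divergence.
            rec_loss = original_dim * bce(imB, fakeA)
            kl_loss = -0.5 * tf.reduce_sum(
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
            vae_loss = tf.reduce_mean(rec_loss + kl_loss)

            # Adversarial part: try to make the frozen discriminator
            # label the reconstruction as real.
            patch_pred = discriminator([imB, fakeA], training=False)
            adv_loss = bce(valid, patch_pred)

            total_loss = vae_loss + adv_loss

        train_vars = encoder.trainable_variables + decoder.trainable_variables
        grads = tape.gradient(total_loss, train_vars)
        g_optimizer.apply_gradients(zip(grads, train_vars))
        return total_loss

The discriminator can keep being updated with train_on_batch exactly as in the original loop; only the combined/generator update moves into the tape, so the KL term no longer has to pass through compile()/add_loss() and the adversarial targets (valid) can be fed in directly.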