Unlike typical generative models, the input to this VAE is an RGB image. If I compile self.combined using the add_loss method, the loss fluctuates between roughly 15000 and -22000. Compiling with mse works fine.
def __init__(self, type='landmark'):
    self.latent_dim = 128
    self.input_shape = (128, 128, 3)
    self.batch_size = 1
    self.original_dim = self.latent_dim * self.latent_dim
    patch = int(self.input_shape[0] / 2**4)
    self.disc_patch = (patch, patch, 1)
    optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

    pd = patch_discriminator(type)
    self.discriminator = pd.discriminator()
    self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    self.discriminator.trainable = False

    vae = VAE(self.latent_dim, type=type)
    encoder = vae.inference_net()
    decoder = vae.generative_net()

    if type == 'image':
        self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 3))
    else:
        self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 1))

    vae_input = tf.keras.layers.Input(shape=self.input_shape)
    self.encoder_out = encoder(vae_input)
    self.decoder_out = decoder(self.encoder_out[2])
    self.generator = tf.keras.Model(vae_input, self.decoder_out)

    vae_loss = self.compute_loss()
    self.generator.add_loss(vae_loss)
    self.generator.compile(optimizer=optimizer)

    valid = self.discriminator([self.decoder_out, self.decoder_out])
    self.combined = tf.keras.Model(vae_input, valid)
    self.combined.add_loss(vae_loss)
    self.combined.compile(optimizer=optimizer)
    # self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    self.dl = DataLoader()
compute_loss computes the VAE loss (reconstruction plus KL) for the VAE. Initially self.orig_out is set to a random normal tensor and is updated in the training loop below.
def compute_loss(self):
    bce = tf.keras.losses.BinaryCrossentropy()
    reconstruction_loss = bce(self.decoder_out, self.orig_out)
    reconstruction_loss = self.original_dim * reconstruction_loss
    z_mean = self.encoder_out[0]
    z_log_var = self.encoder_out[1]
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    return vae_loss
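For reference, the kl_loss computed above is the standard closed-form KL divergence between the diagonal Gaussian posterior and a standard normal prior, with z_mean = \mu and z_log_var = \log\sigma^2:

D_{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big) = -\tfrac{1}{2}\sum_j\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)

Only the reconstruction term depends on self.orig_out.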
Training loop:
def train(self, batch_size=1, epochs=10):
    start_time = datetime.datetime.now()
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)
    threshold = epochs // 10
    for epoch in range(epochs):
        for batch_i, (imA, imB, n_batches) in enumerate(self.dl.load_batch(target='landmark', batch_size=batch_size)):
            self.orig_out = tf.convert_to_tensor(imB, dtype=tf.float32)
            fakeA = self.generator.predict(imA)
            d_real_loss = self.discriminator.train_on_batch([imB, imB], valid)
            d_fake_loss = self.discriminator.train_on_batch([imB, fakeA], fake)
            d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)
            combined_loss = self.combined.train_on_batch(imA)
            # combined_loss = self.combined.train_on_batch(imA, valid)
            elapsed_time = datetime.datetime.now() - start_time
            print(f"[Epoch {epoch}/{epochs}] [Batch {batch_i}/{n_batches}] [D loss: {d_loss}] [G loss: {combined_loss}] time: {elapsed_time}")
If I compile self.combined with the KL loss via the add_loss() method, I cannot pass target outputs to train_on_batch, as shown above. As a result the generator doesn't learn and produces random outputs. How do I compile the VAE together with the discriminator using the KL loss?
I don't know if this will be the right answer, but a VAE like this is easier to handle with a custom TensorFlow training loop, since TensorFlow supports custom training loops directly. You can follow this link, which may contain some relevant information for your problem.
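As a rough sketch of that suggestion (not your existing code): assuming the encoder, decoder and discriminator built in __init__ above, and the [real, fake] discriminator-input convention used in your train loop, one generator update with tf.GradientTape could look roughly like this. The function name generator_train_step and the original_dim default are my own; the loss terms mirror your compute_loss plus an adversarial term against the patch discriminator.

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()

# Sketch of one combined (generator) training step with a custom loop,
# assuming `encoder`, `decoder` and `discriminator` are the Keras models
# from the question, and `valid` is the all-ones patch target.
@tf.function
def generator_train_step(imA, imB, valid, optimizer, original_dim=128 * 128):
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = encoder(imA, training=True)
        fakeA = decoder(z, training=True)

        # Reconstruction loss against the real target image.
        reconstruction_loss = original_dim * bce(imB, fakeA)

        # Closed-form KL divergence to the standard normal prior.
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)

        # Adversarial loss: the (frozen) discriminator should call the fake "real".
        disc_out = discriminator([imB, fakeA], training=False)
        adv_loss = bce(valid, disc_out)

        total_loss = tf.reduce_mean(reconstruction_loss + kl_loss) + adv_loss

    trainable_vars = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(total_loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return total_loss

With a loop like this, the real image imB is passed directly into the loss on every step instead of being stashed in self.orig_out, which sidesteps the add_loss / train_on_batch issue described in the question.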