
Tensorflow apply_gradients() with multiple losses


I am training a model (VAEGAN) with intermediate outputs, and I have two losses:

  • KL divergence loss, which I compute from the output layer.
  • Similarity (reconstruction) loss, which I compute from an intermediate layer.

Can I simply sum them up and apply gradients like below?

with tf.GradientTape() as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training = True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff

    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training = True)
    real_inter_activations, logits_real = self.discriminator(real_images, training = True)

    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff

    total_encoder_loss = kl_loss + rec_loss

grads = tape.gradient(total_encoder_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads, self.encoder.trainable_weights))

Or do I need to separate them, as below, while keeping the tape persistent?

with tf.GradientTape(persistent = True) as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training = True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff
    
    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training = True)
    real_inter_activations, logits_real = self.discriminator(real_images, training = True)
    
    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff

grads_kl_loss = tape.gradient(kl_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_kl_loss, self.encoder.trainable_weights))

grads_rec_loss = tape.gradient(rec_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_rec_loss, self.encoder.trainable_weights))

Solution

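  • Yes, you can generally sum the losses and compute a single gradient. Because the gradient of a sum is the sum of the individual gradients, the step taken with the summed loss is the same as taking both steps one after another. Strictly, this holds for plain gradient descent with both gradients evaluated at the same starting point (which is also what your persistent-tape version does); stateful optimizers such as Adam or SGD with momentum will generally produce a different result when apply_gradients is called twice instead of once.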

    Here's a simple example: say you have two weights and are currently at the point (1, 3) (the "starting point"). The update step from loss 1 (its negative gradient scaled by the learning rate) is (2, -4), and the update step from loss 2 is (1, 2).

    • If you apply the steps one after the other, you will first move to (3, -1) and then to (4, 1).
    • If you sum the gradients first, the overall step is (3, -2). Following this direction from the starting point gets you to (4, 1) as well.
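The arithmetic in this example can be checked with a few lines of plain Python. The variable names are illustrative, and the steps are treated directly as the vectors added to the weights (learning rate and sign already folded in):

```python
# Reproduce the two-weight example: applying each loss's step in turn
# lands at the same point as applying their sum once.

def add(u, v):
    """Element-wise sum of two 2-D vectors given as tuples."""
    return (u[0] + v[0], u[1] + v[1])

start = (1, 3)     # starting point (w1, w2)
step_1 = (2, -4)   # update contributed by loss 1
step_2 = (1, 2)    # update contributed by loss 2

# Apply the two updates one after the other.
intermediate = add(start, step_1)       # (3, -1)
sequential = add(intermediate, step_2)  # (4, 1)

# Sum the updates first, then take a single step.
combined_step = add(step_1, step_2)     # (3, -2)
summed = add(start, combined_step)      # (4, 1)

print(intermediate, sequential, combined_step, summed)
```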
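The same identity can be checked directly with tf.GradientTape. This is a minimal sketch: the variable w and the two loss terms are made-up stand-ins for the encoder weights, kl_loss and rec_loss (they are assumptions, not the model from the question). The computation is recorded once with a persistent tape, and the gradient of the sum is compared with the sum of the two gradients:

```python
import tensorflow as tf

# Hypothetical stand-in for the encoder weights.
w = tf.Variable([1.0, 3.0])

# A persistent tape allows several gradient calls on one recording.
with tf.GradientTape(persistent=True) as tape:
    kl_loss = tf.reduce_sum(w ** 2)            # toy "KL" term
    rec_loss = tf.reduce_sum(3.0 * tf.sin(w))  # toy "reconstruction" term
    total_loss = kl_loss + rec_loss

grad_total = tape.gradient(total_loss, w)  # gradient of the sum
grad_kl = tape.gradient(kl_loss, w)        # gradient of each term separately
grad_rec = tape.gradient(rec_loss, w)
del tape  # release the resources held by the persistent tape

# Both printed vectors should match (up to floating-point rounding).
print(grad_total.numpy(), (grad_kl + grad_rec).numpy())
```

In the actual training step this means one tape.gradient(total_encoder_loss, self.encoder.trainable_weights) call followed by a single apply_gradients is enough; the persistent tape is only needed in this sketch because it asks for three gradients from one recording.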