Tags: tensorflow, keras, tensorflow2.0, tf.keras, tensorflow-probability

Where exactly are the KL losses used after the forward pass?


I've noticed that the KL part of the loss is added to the list self._losses of the Layer class when self.add_loss is called from the call method of the DenseVariational layer (i.e. during the forward pass).
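
For concreteness, here is a minimal stand-in for what I mean: a custom layer that calls self.add_loss from its call method, the same way DenseVariational adds its KL term (the layer and the 0.1 factor are just for illustration):

    import tensorflow as tf

    class AddLossLayer(tf.keras.layers.Layer):
        # stand-in for DenseVariational: adds a penalty via add_loss in call
        def call(self, inputs):
            self.add_loss(0.1 * tf.reduce_sum(tf.square(inputs)))
            return inputs

    inputs = tf.keras.Input(shape=(3,))
    outputs = AddLossLayer()(tf.keras.layers.Dense(2)(inputs))
    model = tf.keras.Model(inputs, outputs)

    model(tf.ones((4, 3)))   # forward pass triggers add_loss
    print(model.losses)      # the added tensor shows up here, via Layer._losses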

But how is this list self._losses (or the losses property of the same Layer class) treated during training? Where is it used during training? For example, are these losses summed or averaged before being added to the final loss? I would like to SEE the ACTUAL CODE.

I would like to know how exactly these losses are combined with the loss function that you specify in the compile method (and that is minimized during fit). Can you provide me with the code that combines them? Note that I am interested in the Keras that is shipped with TensorFlow (because that's the one I am using).


Solution

  • Actually, the part where the total loss is computed is in the compile method of the Model class, specifically in this line:

        # Compute total loss.
        # Used to keep track of the total loss value (stateless).
        # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
        #                   loss_weight_2 * output_2_loss_fn(...) +
        #                   layer losses.
        self.total_loss = self._prepare_total_loss(masks)
    

    The _prepare_total_loss method adds the regularization and layer losses to the total loss (i.e. all the losses are summed together) and then averages the result over the batch axis in these lines:

            # Add regularization penalties and other layer-specific losses.
            for loss_tensor in self.losses:
                total_loss += loss_tensor
    
        return K.mean(total_loss)
    

    Actually, self.losses is not an attribute of the Model class; rather, it comes from the parent class, i.e. Network, and returns all the layer-specific losses as a list. Further, to resolve any confusion, total_loss in the code above is a single tensor which is equal to the summation of all the losses in the model (i.e. the loss function values and the layer-specific losses). Note that loss functions by definition must return one loss value per input sample (not for the whole batch). Therefore, K.mean(total_loss) averages all these values over the batch axis into one final loss value which should be minimized by the optimizer.
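
    To make that arithmetic concrete, here is a toy sketch of what this amounts to (plain NumPy standing in for the backend ops, with made-up numbers):

      import numpy as np

      per_sample_loss = np.array([0.5, 0.7, 0.3, 0.9])  # loss fn value per sample
      kl_loss = 0.2                                     # a layer-specific (KL) loss

      total_loss = per_sample_loss + kl_loss            # layer losses are summed in
      final_loss = total_loss.mean()                    # K.mean over the batch axis

      print(final_loss)  # 0.8, i.e. per_sample_loss.mean() + kl_loss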


    As for tf.keras, this is more or less the same as in native Keras; however, the structure and flow of things is a bit different, as explained below.

    First, in the compile method of the Model class, a loss container is created which holds and computes the value of the loss functions:

      self.compiled_loss = compile_utils.LossesContainer(
          loss, loss_weights, output_names=self.output_names)
    

    Next, in the train_step method of the Model class, this container is called to compute the loss value for a batch:

      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
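
    If you want to reproduce this combination yourself, e.g. in a custom training loop, the equivalent is to add the sum of model.losses to the value of your loss function. A minimal sketch under that assumption (the toy model, data and hyperparameters are arbitrary):

      import tensorflow as tf

      # toy model; the l2 activity regularizer creates a layer-specific loss
      # via the same add_loss mechanism as the KL term of DenseVariational
      inputs = tf.keras.Input(shape=(3,))
      x = tf.keras.layers.Dense(
          4, activity_regularizer=tf.keras.regularizers.l2(0.01))(inputs)
      outputs = tf.keras.layers.Dense(1)(x)
      model = tf.keras.Model(inputs, outputs)

      loss_fn = tf.keras.losses.MeanSquaredError()
      optimizer = tf.keras.optimizers.Adam()

      x_batch = tf.random.normal((8, 3))
      y_batch = tf.random.normal((8, 1))

      with tf.GradientTape() as tape:
          y_pred = model(x_batch, training=True)
          loss = loss_fn(y_batch, y_pred)     # batch-averaged loss function value
          loss += tf.add_n(model.losses)      # plus all layer-specific losses

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))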
    

    As you can see above, self.losses is passed to this container. As in the native Keras implementation, self.losses contains all the layer-specific loss values, with the only difference that in tf.keras it's implemented in the Layer class (instead of the Network class as in native Keras). Note that Model is a subclass of Network, which itself is a subclass of Layer. Now, let's see how regularization_losses is treated in the __call__ method of LossesContainer (these lines):

      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)
    
    
      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)
    
    
    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))
    
    
    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)
    
    
      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    

    As you can see, the regularization_losses are summed (with math_ops.add_n) and added to total_loss, which therefore holds the sum of all the layer-specific losses plus the batch-averaged values of all the loss functions (i.e. a single scalar value).
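
    You can also check this combination empirically. A small sketch (using a kernel_regularizer to create a layer-specific loss, since it does not depend on the inputs and is easy to compare; data and sizes are arbitrary):

      import numpy as np
      import tensorflow as tf

      inputs = tf.keras.Input(shape=(3,))
      outputs = tf.keras.layers.Dense(
          1, kernel_regularizer=tf.keras.regularizers.l2(0.1))(inputs)
      model = tf.keras.Model(inputs, outputs)
      model.compile(loss="mse")

      x = np.random.rand(16, 3).astype("float32")
      y = np.random.rand(16, 1).astype("float32")

      reported = model.evaluate(x, y, verbose=0)
      manual = (tf.keras.losses.MeanSquaredError()(y, model(x))
                + tf.add_n(model.losses))
      print(reported, float(manual))  # should match up to floating point error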