I've noticed that the KL part of the loss is added to the list `self._losses` of the `Layer` class when `self.add_loss` is called from the `call` method of `DenseVariational` (i.e. during the forward pass). But how is this list `self._losses` (or the `losses` method of the same `Layer` class) treated during training? Where is it called from during training? For example, are the values summed or averaged before being added to the final loss? I would like to SEE the ACTUAL CODE.
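For concreteness, here is a minimal sketch of the kind of layer I mean; it is only illustrative, not the actual `DenseVariational` implementation:

```python
import tensorflow as tf

# Illustrative only: a layer that registers an extra loss term from its call
# method, similar in spirit to the KL term added by DenseVariational.
class MyRegularizedDense(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        outputs = self.dense(inputs)
        # This tensor ends up in self._losses / self.losses and is picked up
        # by any model that contains this layer.
        self.add_loss(1e-3 * tf.reduce_sum(tf.square(outputs)))
        return outputs
```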
I would like to know how exactly these losses are combined with the loss that you specify in the `fit` method. Can you provide me with the code that combines them? Note that I am interested in the Keras that is shipped with TensorFlow (because that's the one I am using).
Actually, the part where the total loss is computed is in the `compile` method of the `Model` class, specifically in this line:
```python
# Compute total loss.
# Used to keep track of the total loss value (stateless).
# eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
#                   loss_weight_2 * output_2_loss_fn(...) +
#                   layer losses.
self.total_loss = self._prepare_total_loss(masks)
```
The `_prepare_total_loss` method adds the regularization and layer losses to the total loss (i.e. all the losses are summed together) and then averages the result over the batch axis in these lines:
```python
# Add regularization penalties and other layer-specific losses.
for loss_tensor in self.losses:
  total_loss += loss_tensor

return K.mean(total_loss)
```
Actually, `self.losses` is not an attribute of the `Model` class itself; rather, it belongs to the parent class, `Network`, and returns all the layer-specific losses as a list. Further, to resolve any confusion: `total_loss` in the code above is a single tensor which is equal to the summation of all the losses in the model (i.e. the loss function values and the layer-specific losses). Note that, by definition, loss functions must return one loss value per input sample (not for the whole batch). Therefore, `K.mean(total_loss)` averages all these values over the batch axis into one final loss value, which is then minimized by the optimizer.
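To illustrate that arithmetic with made-up numbers (plain TensorFlow ops, not the actual Keras code):

```python
import tensorflow as tf

per_sample_loss = tf.constant([0.5, 1.5])  # loss function value per sample
kl_loss = tf.constant(0.1)                 # a scalar layer loss added via add_loss

# Layer losses are added to the per-sample loss tensor (broadcast over the
# batch axis), then the result is averaged, as in K.mean(total_loss).
total_loss = per_sample_loss + kl_loss      # [0.6, 1.6]
final_loss = tf.reduce_mean(total_loss)     # 1.1, the value the optimizer minimizes
```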
As for `tf.keras`, this is more or less the same as in native `keras`; however, the structure and flow of things is a bit different, as explained below.
First, in the `compile` method of the `Model` class, a loss container is created which holds and computes the values of the loss functions:
```python
self.compiled_loss = compile_utils.LossesContainer(
    loss, loss_weights, output_names=self.output_names)
```
Next, in the `train_step` method of the `Model` class, this container is called to compute the loss value of a batch:
```python
loss = self.compiled_loss(
    y, y_pred, sample_weight, regularization_losses=self.losses)
```
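Incidentally, this is the same call you make yourself when overriding `train_step` (see the "Customizing what happens in fit()" guide); the class below is just an illustrative sketch of that pattern:

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # self.losses (the layer-specific losses) is forwarded to the
            # compiled loss container exactly as in the built-in train_step.
            loss = self.compiled_loss(y, y_pred,
                                      regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```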
As you can see in the snippet from `train_step` above, `self.losses` is passed to this container. As in the native Keras implementation, `self.losses` contains all the layer-specific loss values, with the only difference that in `tf.keras` it is implemented in the `Layer` class (instead of the `Network` class, as in native Keras). Note that `Model` is a subclass of `Network`, which itself is a subclass of `Layer`. Now, let's see how `regularization_losses` is treated in the `__call__` method of `LossesContainer` (these lines):
```python
if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
    loss_obj.reduction == losses_utils.ReductionV2.AUTO):
  loss_value = losses_utils.scale_loss_for_distribution(loss_value)

loss_values.append(loss_value)
loss_metric_values.append(loss_metric_value)

if regularization_losses:
  regularization_losses = losses_utils.cast_losses_to_common_dtype(
      regularization_losses)
  reg_loss = math_ops.add_n(regularization_losses)
  loss_metric_values.append(reg_loss)
  loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

if loss_values:
  loss_metric_values = losses_utils.cast_losses_to_common_dtype(
      loss_metric_values)
  total_loss_metric_value = math_ops.add_n(loss_metric_values)
  self._loss_metric.update_state(
      total_loss_metric_value, sample_weight=batch_dim)

  loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
  total_loss = math_ops.add_n(loss_values)
  return total_loss
```
As you can see, `regularization_losses` is added (via `add_n`) to `total_loss`, which therefore holds the sum of the layer-specific losses plus the batch-averaged values of all the loss functions (so it is a single scalar value).
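If you want to verify this behavior yourself, here is a small (hypothetical) sanity check: the loss reported by a `tf.keras` model should equal the batch-averaged loss function value plus the layer-specific losses added via `add_loss`:

```python
import numpy as np
import tensorflow as tf

class AddConstantLoss(tf.keras.layers.Layer):
    def call(self, inputs):
        # Register a constant 0.1 loss during the forward pass.
        self.add_loss(0.1 * tf.reduce_mean(tf.ones_like(inputs)))
        return inputs

inputs = tf.keras.Input(shape=(1,))
x = AddConstantLoss()(inputs)
outputs = tf.keras.layers.Dense(1, kernel_initializer="zeros",
                                bias_initializer="zeros")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse")

x_data = np.array([[1.0], [2.0]], dtype=np.float32)
y_data = np.array([[1.0], [3.0]], dtype=np.float32)

# Predictions are all zeros, so per-sample MSE is [1.0, 9.0]:
# mean([1.0, 9.0]) + 0.1 = 5.1
print(model.evaluate(x_data, y_data, verbose=0))  # ~5.1
```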