Tags: machine-learning, tensorflow, deep-learning, backpropagation, loss-function

batch_loss and total_loss=tf.get_total_loss() in tensorflow


I ran into a problem while reading the im2txt source code.

There are batch_loss and total_loss: batch_loss is computed for every batch of data and is added to the tf.GraphKeys.LOSSES collection via the tf.losses.add_loss(batch_loss) call. total_loss is obtained from tf.losses.get_total_loss(), which sums all the losses in tf.GraphKeys.LOSSES (plus any regularization losses).
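To make the pattern concrete, here is a minimal sketch of that mechanism, assuming TensorFlow 1.x (the API family im2txt uses). The tensors `logits` and `targets` are invented stand-ins, not the actual im2txt code:

    import tensorflow as tf  # assumes TensorFlow 1.x

    # Toy stand-ins for the model's outputs and labels (hypothetical, for illustration).
    logits = tf.get_variable("logits", shape=[4, 10],
                             initializer=tf.random_normal_initializer())
    targets = tf.constant([1, 2, 3, 4], dtype=tf.int32)

    # Per-batch loss, added to the tf.GraphKeys.LOSSES collection.
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                            logits=logits)
    batch_loss = tf.reduce_mean(losses)
    tf.losses.add_loss(batch_loss)

    # total_loss combines everything in tf.GraphKeys.LOSSES
    # (plus regularization losses, if any were defined).
    total_loss = tf.losses.get_total_loss()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Identical here, since only one loss was added and there is no regularization.
        print(sess.run([batch_loss, total_loss]))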

Question: why are the parameters updated using total_loss? This has confused me for many days.


Solution

  • A summary of the discussion in the comments:

    The training loss is computed in the forward pass over the mini-batch, but the actual loss value is not needed to start backpropagation. Backprop starts from the error signal, which is the derivative of the loss function evaluated at the values from the forward pass. So the loss value itself does not affect the parameter update; it is reported simply to monitor the training process. For example, if the loss does not decrease, that is a sign to double-check the model and hyperparameters. It is therefore harmless to smooth the reported loss by averaging just to make the chart look nicer; see the sketch after this answer.

    See this post for more details.
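The following toy sketch illustrates the answer's point, again assuming TensorFlow 1.x; the tiny linear model is invented for illustration. The optimizer builds the update from the gradient d(loss)/d(w), while the scalar loss value is fetched separately, purely for monitoring:

    import tensorflow as tf  # assumes TensorFlow 1.x

    # A hypothetical one-weight linear model, just to have a differentiable loss.
    x = tf.constant([[1.0, 2.0]])
    w = tf.get_variable("w", shape=[2, 1], initializer=tf.ones_initializer())
    y_true = tf.constant([[3.0]])

    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y_true))

    # The update is computed from the gradients (the error signal propagated
    # backward), not from the scalar loss value itself.
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    grads_and_vars = optimizer.compute_gradients(loss)
    train_op = optimizer.apply_gradients(grads_and_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss_val, _ = sess.run([loss, train_op])
        print("reported loss:", loss_val)  # used only for charts/monitoring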