Tags: python, deep-learning, keras, loss-function

Keras training with batches: Is the training loss computed before or after each optimization step?


This is probably a very basic question, but I wasn't able to find an answer to it: when I train a network with Keras using batches, the console output shows and keeps updating the current loss value on the training set during each training epoch. As I understand it, this loss value is computed over the current batch (as a proxy for the overall loss) and probably averaged with the loss values that were calculated for the previous batches. But there are two possible moments to take the loss value of the current batch: either before updating the parameters or afterwards. Can someone tell me which of the two is correct? From what I observe, I would rather guess it is after the optimization step.
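To make the question concrete: the number I am talking about is the loss value Keras hands to its callbacks (and hence to the progress bar) after every batch. Here is a minimal sketch, assuming tf.keras on TensorFlow 2.x, with a toy model and random data; the BatchLossLogger callback is only for illustration, and depending on the Keras version the batch-level value may be the raw per-batch loss or a running average over the epoch so far:

    import tensorflow as tf

    # Toy model and data, only to show which number the progress bar is fed.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")

    class BatchLossLogger(tf.keras.callbacks.Callback):
        """Print the loss value Keras reports after every training batch."""
        def on_train_batch_end(self, batch, logs=None):
            print(f"batch {batch}: reported loss = {logs['loss']:.6f}")

    x = tf.random.normal((256, 4))
    y = tf.random.normal((256, 1))
    model.fit(x, y, batch_size=32, epochs=1, verbose=0,
              callbacks=[BatchLossLogger()])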

Reason why I ask this question: I was training a network and saw a behavior where the training loss (MSE of two embeddings) decreased as expected (by several orders of magnitude), but the validation loss stayed the same. At first I thought it might be due to overfitting. Consequently, as the training dataset is quite big (200k images), I decided to decrease the epoch size so that the validation set is evaluated more often, resulting in epochs smaller than trainingSetSize/batchSize. Even then I saw the training loss decreasing from epoch to epoch (with the validation loss still staying the same), which I found quite intriguing, as the network was still in the phase where it saw the training data for the very first time. In my understanding this means that either there is some nasty bug in my setup or the displayed training loss is computed after taking an optimization step. Otherwise, the loss on a new, never-seen batch and on the validation set should behave at least similarly.

Even if I assume that the loss is calculated after each optimization step: assuming my network makes no useful progress, as suggested by the validation-set evaluation, it should also behave arbitrarily when seeing a new, never-seen batch. Then the whole decrease in training loss would only be due to the optimization step (which would be very good for the batch at hand but not for other data, obviously, so also a kind of overfitting). This would mean, if the training loss keeps decreasing, that the optimization step per batch gets more effective. I am using the Adam optimizer, which I know is adaptive, but is it really possible to see a continuous and substantial decrease in training loss while, in reality, the network doesn't learn any useful generalization?
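In case it helps, this is the kind of check I have in mind (a sketch only, assuming tf.keras on TensorFlow 2.x; ProbeModel and the toy model/data are made up for illustration): override train_step so that the loss on the same batch is reported both before and after the weight update, and compare the two with what model.fit normally displays.

    import tensorflow as tf

    class ProbeModel(tf.keras.Model):
        """Report the loss on the current batch both before and after the update."""
        def train_step(self, data):
            x, y = data
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss_before = self.compiled_loss(y, y_pred)
            grads = tape.gradient(loss_before, self.trainable_variables)
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
            # Second forward pass on the same batch, with the updated weights.
            loss_after = self.compiled_loss(y, self(x, training=True))
            return {"loss_before_update": loss_before,
                    "loss_after_update": loss_after}

    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = ProbeModel(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    model.fit(tf.random.normal((256, 4)), tf.random.normal((256, 1)),
              batch_size=32, epochs=1)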


Solution

  • The loss is computed before the optimization step. The reason for this is efficiency, and it has to do with how back-propagation works.

    In particular, suppose we want to minimize ||A(x, z) - y||^2 w.r.t. z. Then when we perform back-propagation we need to evaluate this computational graph:

    A(x, z) -> grad ||. - y||^2 -> backpropagate
    

    Now, if we add an "evaluate loss" node to this graph and evaluate the loss before updating the parameters, the computational graph looks like this:

               >  grad ||. - y||^2 -> backpropagate
             /
    A(x, z) 
             \
               >  ||. - y||^2
    

    On the other hand, if we evaluate the loss after updating the parameters, the graph looks like this:

    A(x, z) -> grad ||. - y||^2 -> backpropagate -> A(x, z) -> ||. - y||^2
    

    Hence, if we evaluate the loss after updating, we need to compute A(x, z) twice, whereas if we compute it before updating we only need to compute it once, because the same forward pass serves both the reported loss and back-propagation. Evaluating the loss before the update therefore comes essentially for free, while evaluating it afterwards would cost an extra forward pass on every batch.
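
    To make this concrete, here is roughly what one training step looks like when written out by hand with tf.GradientTape (a sketch, assuming tf.keras on TensorFlow 2.x; the toy model, data, and the loss_avg metric are made up): the loss value that gets logged and averaged is the very tensor produced by the single forward pass that back-propagation differentiates, so A(x, z) is evaluated only once per batch.

    import tensorflow as tf

    # Hand-written training step: one forward pass yields both the reported
    # loss and the gradients, i.e. the "evaluate loss before updating" branch.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    loss_fn = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam()
    loss_avg = tf.keras.metrics.Mean()      # running mean, like the progress bar

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal((256, 4)), tf.random.normal((256, 1)))).batch(32)

    for step, (x, y) in enumerate(dataset):
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)          # A(x, z), computed once
            loss = loss_fn(y, y_pred)                 # ||A(x, z) - y||^2
        grads = tape.gradient(loss, model.trainable_variables)   # backpropagate
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        loss_avg.update_state(loss)                   # the value shown is pre-update
        print(f"step {step}: running mean loss = {float(loss_avg.result()):.6f}")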