Tags: python, optimization, deep-learning, pytorch, loss-function

Pytorch: Why does altering the scale of the loss functions improve the convergence in some models?


I have a question about a fairly complex loss function: a variational autoencoder loss.
It is made up of two reconstruction losses, a KL divergence term, and a discriminator acting as a regularizer. All of those losses are on the same scale, but I have found that increasing one of the reconstruction losses by a factor of 20 (while leaving the rest on their previous scale) considerably improves the performance of my model.
Since I am still fairly new to deep learning, I don't completely understand why this happens, or how I could spot this sort of thing in future models.
Any advice/explanation is greatly appreciated.


Solution

  • To summarize your setting first:

    loss = alpha1 * loss1 + alpha2 * loss2
    

    When computing the gradients for backpropagation, we propagate back through this formula. Backpropagating through our error function first gives us the gradient:

    dError/dLoss
    

    To continue our propagation downwards, we now want to compute dError/dLoss1 and dError/dLoss2.

    dError/dLoss1 can be expanded to dError/dLoss * dLoss/dLoss1 via the chain rule (https://en.wikipedia.org/wiki/Chain_rule). We already computed dError/dLoss, so we only need the derivative of Loss with respect to Loss1, which is

    dLoss/dLoss1 = alpha1

    The backpropagation now continues until we reach our weights (dLoss1/dWeight). The gradient our weight receives is:

    dError/dWeight = dError/dLoss * dLoss/dLoss1 * dLoss1/dWeight = dError/dLoss * alpha1 * dLoss1/dWeight
    

    As you can see, the gradient used to update our weight now depends on alpha1, the factor we use to scale Loss1. If we increase alpha1 while leaving alpha2 unchanged, the gradients coming from Loss1 have a larger impact than the gradients from Loss2, and this therefore changes the optimization of our model. The sketches below illustrate a typical weighted multi-term loss and verify this scaling effect numerically.
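
    The following is a minimal, hypothetical sketch of how such a weighted multi-term VAE loss could be combined in PyTorch. All names (recon_a, recon_b, mu, logvar, disc_logits, the alpha_* weights) and the choice of individual loss functions are assumptions for illustration, not the asker's actual code; the point is only that each alpha plays the role of alpha1/alpha2 above.

    import torch
    import torch.nn.functional as F

    def total_loss(recon_a, recon_b, target, mu, logvar, disc_logits,
                   alpha_recon_a=20.0, alpha_recon_b=1.0,
                   alpha_kl=1.0, alpha_disc=1.0):
        # Two reconstruction terms (e.g. pixel-wise and feature-wise).
        loss_recon_a = F.mse_loss(recon_a, target)
        loss_recon_b = F.l1_loss(recon_b, target)

        # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior.
        loss_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        # Adversarial regularizer: push the discriminator output for the
        # reconstruction towards the "real" label.
        loss_disc = F.binary_cross_entropy_with_logits(
            disc_logits, torch.ones_like(disc_logits))

        # Each alpha scales how strongly the gradients of its term
        # influence the shared encoder/decoder weights.
        return (alpha_recon_a * loss_recon_a
                + alpha_recon_b * loss_recon_b
                + alpha_kl * loss_kl
                + alpha_disc * loss_disc)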
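
    And here is a quick numerical check of the chain-rule argument, using a toy linear layer (the names are again illustrative): multiplying a loss term by alpha multiplies the gradient it contributes to the weights by the same alpha.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(8, 4)
    target = torch.randn(8, 4)

    def weight_grad(alpha1, alpha2):
        # Re-create the layer with identical initial weights on every call.
        torch.manual_seed(1)
        layer = torch.nn.Linear(4, 4)
        out = layer(x)
        loss1 = F.mse_loss(out, target)
        loss2 = out.abs().mean()  # some second loss term on the same output
        loss = alpha1 * loss1 + alpha2 * loss2
        loss.backward()
        return layer.weight.grad.clone()

    g_base = weight_grad(alpha1=1.0, alpha2=0.0)    # gradient from loss1 alone
    g_scaled = weight_grad(alpha1=20.0, alpha2=0.0) # the same term scaled by 20

    # dError/dWeight = alpha1 * dLoss1/dWeight, so the ratio is exactly 20.
    print(torch.allclose(g_scaled, 20 * g_base))    # prints: True

    In practice, this is why per-term weights are tuned (or the terms normalized) so that no single loss dominates the gradients flowing into the shared weights.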