
Should I normalize or weight losses when combining them in pytorch?


[Image: network architecture]

I have a neural network with 3 heads, one with a focal loss and the two others with L1 losses. They are combined by summing: loss = hm_loss + off_loss + wh_loss. However, the typical ranges of values for the loss terms differ. Is this an issue? Should I weight the loss terms, or should I normalize the network outputs?
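For reference, a minimal sketch of this setup (the tensor shapes and the binary cross-entropy stand-in for the focal loss are illustrative placeholders only):

```python
import torch
import torch.nn as nn

# Illustrative tensors standing in for the three head outputs and their targets.
pred_hm,  gt_hm  = torch.rand(4, 1, 64, 64, requires_grad=True), torch.rand(4, 1, 64, 64)
pred_off, gt_off = torch.rand(4, 2, 64, 64, requires_grad=True), torch.rand(4, 2, 64, 64)
pred_wh,  gt_wh  = torch.rand(4, 2, 64, 64, requires_grad=True), torch.rand(4, 2, 64, 64)

l1 = nn.L1Loss()
# Stand-in for the focal loss on the heatmap head; any focal-loss implementation goes here.
hm_loss  = nn.functional.binary_cross_entropy(pred_hm.sigmoid(), gt_hm)
off_loss = l1(pred_off, gt_off)
wh_loss  = l1(pred_wh, gt_wh)

loss = hm_loss + off_loss + wh_loss   # plain, unweighted sum
loss.backward()
```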


Solution

  • This is a typical challenge when performing multi-task learning. There are many methods to handle it, but as with most things in this field, there is no single solution that fits every case. The most straightforward approach is indeed to weight the different loss components. You can do so by running a grid search or random search over the three weights, or by trying to level the three components of your loss based on their orders of magnitude. The general idea is that if you give higher precedence to one of the loss terms, the gradient corresponding to that term will dominate during backpropagation and the parameter update. A minimal weighting sketch is given after this answer.

    I recommend you read more on multi-task learning. For example, you could start with "Multi-Task Learning for Dense Prediction Tasks: A Survey", Simon Vandenhende et al., TPAMI 2021.
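As a minimal sketch of static loss weighting (assuming the three loss terms are already computed as scalar tensors; the weight values below are illustrative starting points, not recommendations):

```python
import torch

# Illustrative weights; in practice they are tuned by grid/random search or chosen so
# that the weighted terms end up on comparable orders of magnitude.
w_hm, w_off, w_wh = 1.0, 1.0, 0.1

def combined_loss(hm_loss: torch.Tensor,
                  off_loss: torch.Tensor,
                  wh_loss: torch.Tensor) -> torch.Tensor:
    # Each weight scales the gradient contribution of its term during backpropagation.
    return w_hm * hm_loss + w_off * off_loss + w_wh * wh_loss

# Dummy scalar losses on very different scales, just to show the call:
loss = combined_loss(torch.tensor(0.8, requires_grad=True),
                     torch.tensor(0.05, requires_grad=True),
                     torch.tensor(12.0, requires_grad=True))
loss.backward()
```

Treat the weights like any other hyperparameter: start from values that bring the terms onto similar scales, then refine them based on validation performance.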