
In PyTorch, in what situations does a loss function need to inherit from nn.Module?


I am confused about loss functions in PyTorch. Some people define the loss function as a normal Python function, while others define it as a class that inherits from nn.Module. In what situations do we need to define the loss function by inheriting from nn.Module? Many thanks.
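To make the two styles concrete, here is a minimal sketch of a parameter-free mean-squared-error loss written both ways. The names `mse_loss_fn` and `MSELossModule` are hypothetical; the only library call used for comparison is PyTorch's built-in `torch.nn.functional.mse_loss` with its default `reduction='mean'`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Style 1: a plain Python function (hypothetical name); it holds no state
def mse_loss_fn(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Mean of the squared element-wise differences
    return ((pred - target) ** 2).mean()

# Style 2: the same computation wrapped in an nn.Module subclass (hypothetical name)
class MSELossModule(nn.Module):
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return ((pred - target) ** 2).mean()

pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([1.0, 0.0, 3.0])

loss_a = mse_loss_fn(pred, target)
loss_b = MSELossModule()(pred, target)

# Autograd tracks the tensor operations in either style, so backward() works the same
loss_a.backward()
```

Both versions compute the same value and gradient, and the module version owns no parameters, so for a loss like this the class wrapper adds nothing beyond style.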


Solution

  • Generally, inheriting from nn.Module is only necessary when the module has trainable parameters (nn.Module registers them so they show up in .parameters() and can be updated by the optimizer); otherwise inheriting is optional.

    The same applies to loss functions: if a loss contains no such parameters (which I assume is the common case), no inheritance is needed. Autograd differentiates a plain function just as well.
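For the case where inheritance does matter, here is a minimal sketch of a loss with one trainable parameter. The class name `LearnableScaleMSE` and its weighting `exp(-log_var) * mse + log_var` are illustrative assumptions, not something from the answer above; the point is only that an `nn.Parameter` assigned inside an `nn.Module` is registered automatically.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible model init and data

class LearnableScaleMSE(nn.Module):
    """Hypothetical loss with one trainable scalar `log_var` that rescales
    the MSE term: loss = exp(-log_var) * mse + log_var (illustrative choice)."""

    def __init__(self) -> None:
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it, so it appears in
        # .parameters() and state_dict() and moves with .to(device)
        self.log_var = nn.Parameter(torch.zeros(()))

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mse = ((pred - target) ** 2).mean()
        return torch.exp(-self.log_var) * mse + self.log_var

model = nn.Linear(3, 1)
criterion = LearnableScaleMSE()
# The loss's own parameter is handed to the optimizer alongside the model's
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(criterion.parameters()), lr=0.1
)

x = torch.randn(8, 3)
y = torch.randn(8, 1)

before = criterion.log_var.item()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
after = criterion.log_var.item()  # updated by the optimizer step
```

Without the nn.Module base class, `log_var` would be a loose tensor that you would have to track, save, move between devices, and pass to the optimizer yourself.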