
Which PyTorch modules are affected by model.eval() and model.train()?


The model.eval() method switches certain modules (layers) into a mode where they behave differently than during training. Some examples are listed in the docs:

This has [an] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Is there an exhaustive list of which modules are affected?
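For context, a minimal sketch of the behavior in question: Dropout and BatchNorm1d are two of the documented examples, and toggling `train()`/`eval()` changes their forward passes (dropout masks activations only in train mode; batch norm uses batch statistics in train mode and running statistics in eval mode).

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Dropout(p=0.5),    # affected: randomly zeroes activations in train mode
    nn.BatchNorm1d(4),    # affected: batch stats in train mode, running stats in eval mode
)
x = torch.randn(8, 4)

model.train()             # recursively sets .training = True on all submodules
out_train_a = model(x)
out_train_b = model(x)
print(torch.equal(out_train_a, out_train_b))  # False: dropout masks differ per call

model.eval()              # recursively sets .training = False
with torch.no_grad():
    out_eval_a = model(x)
    out_eval_b = model(x)
print(torch.equal(out_eval_a, out_eval_b))    # True: both layers are deterministic in eval mode
```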


Solution

  • In addition to the info provided by @iacob:

    | Base class           | Modules                                                | Criteria                                    |
    |----------------------|--------------------------------------------------------|---------------------------------------------|
    | `RNNBase`            | `RNN`, `LSTM`, `GRU`                                   | `dropout > 0` (default: 0)                  |
    | `Transformer` layers | `Transformer`, `TransformerEncoder`, `TransformerDecoder` | `dropout > 0` (`Transformer` default: 0.1) |
    | Lazy variants        | `LazyBatchNorm` (currently nightly, merged PR)         | `track_running_stats=True`                  |
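The criteria column above can be checked directly. As a hedged sketch of the `RNNBase` row: `nn.LSTM` only applies dropout between stacked layers, so the `dropout > 0` criterion only has an observable effect with `num_layers > 1`; with that setup, train-mode and eval-mode outputs differ, while eval-mode output is deterministic.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

# dropout is applied between stacked LSTM layers, so num_layers must be > 1
lstm = nn.LSTM(input_size=10, hidden_size=8, num_layers=2, dropout=0.5)

lstm.train()                       # dropout active
out_train, _ = lstm(x)

lstm.eval()                        # dropout inactive
with torch.no_grad():
    out_eval_a, _ = lstm(x)
    out_eval_b, _ = lstm(x)

print(torch.equal(out_eval_a, out_eval_b))  # True: eval-mode forward is deterministic
print(torch.equal(out_train, out_eval_a))   # False: dropout perturbed the train-mode output
```

With `dropout=0` (the default), the same comparison would show identical outputs in both modes, which is why the table lists `dropout > 0` as the criterion.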