python · neural-network · deep-learning · pytorch · batch-normalization

BatchNorm momentum convention PyTorch


Is the BatchNorm momentum convention in PyTorch (default = 0.1) correct? In other libraries such as TensorFlow the default usually seems to be 0.9 or 0.99. Or are we just using a different convention?


Solution

  • It seems that the parametrization convention is different in PyTorch than in TensorFlow, so that a momentum of 0.1 in PyTorch is equivalent to a decay of 0.9 in TensorFlow.

    To be more precise:

    In Tensorflow:

    running_mean = decay*running_mean + (1-decay)*new_value
    

    In PyTorch:

    running_mean = (1-momentum)*running_mean + momentum*new_value
    

    Note that PyTorch calls its parameter momentum rather than decay. This means that a momentum value of m in PyTorch is equivalent to a decay value of (1 - m) in TensorFlow.
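
    This is easy to verify empirically. Below is a minimal sketch (assuming a standard PyTorch install) that runs a single forward pass through nn.BatchNorm1d and checks that the running mean follows the PyTorch-style update shown above.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    momentum = 0.1  # PyTorch default

    bn = nn.BatchNorm1d(num_features=3, momentum=momentum)
    bn.train()  # running statistics are only updated in training mode

    x = torch.randn(8, 3)
    batch_mean = x.mean(dim=0)

    # Expected running mean after one pass (running_mean starts at 0)
    expected = (1 - momentum) * torch.zeros(3) + momentum * batch_mean

    bn(x)  # the forward pass updates bn.running_mean in place

    print(torch.allclose(bn.running_mean, expected))  # True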