
Why does Keras BatchNorm produce different output than PyTorch?


Torch: 1.9.0+cu111

TensorFlow-gpu: 2.5.0

I came across something strange: when I use the BatchNormalization layer of TensorFlow 2.5 and the BatchNorm2d layer of PyTorch 1.9 on the same tensor, the results are quite different (the TensorFlow output is close to 1, the PyTorch output is close to 0). At first I thought the difference came from momentum and epsilon, but after setting them to the same values, the results were still different.

from torch import nn
import torch
# PyTorch layout is channels-first (N, C, H, W): 100 channels
x = torch.ones((20, 100, 35, 45))
a = nn.Sequential(
            # nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True),
            nn.BatchNorm2d(100)
        )
b = a(x)  # a newly created nn.Module is in training mode

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import *
# Keras layout is channels-last (N, H, W, C) by default: 100 channels
x = tf.ones((20, 35, 45, 100))
a = keras.models.Sequential([
            # Conv2D(128, (1, 1), (1, 1), padding='same', use_bias=True),
            BatchNormalization()
        ])
b = a(x)  # no training argument is passed

(Screenshots: the result of TensorFlow; the result of PyTorch.)


Solution

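  • BatchNormalization works differently in training and inference. A newly created PyTorch module is in training mode, so the BatchNorm2d in the question normalizes with the statistics of the current batch (all ones: mean 1, variance 0, so the output is 0). Calling a Keras model directly, outside fit() and without training=True, runs it in inference mode, so it normalizes with the initial moving statistics (mean 0, variance 1, so the output is close to 1).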

    During training (i.e. when using fit() or when calling the layer/model with the argument training=True), the layer normalizes its output using the mean and standard deviation of the current batch of inputs. That is to say, for each channel being normalized, the layer returns

    gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta
    

    where:

    • epsilon is a small constant (configurable as part of the constructor arguments)
    • gamma is a learned scaling factor (initialized as 1), which can be disabled by passing scale=False to the constructor.
    • beta is a learned offset factor (initialized as 0), which can be disabled by passing center=False to the constructor.
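A minimal pure-Python sketch of the training-mode formula above, applied to one channel of the all-ones input from the question (all N*H*W values of that channel flattened into a list). The helper name `batch_norm_train` is hypothetical, not a Keras or PyTorch API. Because every value equals the batch mean, the numerator is 0 and the output is 0 regardless of epsilon, which matches the PyTorch result in the question:

```python
import math

def batch_norm_train(values, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Hypothetical helper: training-mode batch norm for the values of ONE channel."""
    n = len(values)
    mean = sum(values) / n
    # Population (biased) variance of the current batch
    var = sum((v - mean) ** 2 for v in values) / n
    return [gamma * (v - mean) / math.sqrt(var + epsilon) + beta for v in values]

# One channel of the question's all-ones input: 20 * 35 * 45 values
channel = [1.0] * (20 * 35 * 45)
out = batch_norm_train(channel)
print(out[0])  # 0.0
```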

    During inference (i.e. when using evaluate() or predict(), or when calling the layer/model with the argument training=False, which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training. That is to say, it returns

    gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta.
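The same kind of sketch for the inference formula, with a hypothetical helper `batch_norm_infer`. It assumes a freshly created layer that has never been run in training mode: both Keras and PyTorch initialize the moving mean to 0 and the moving variance to 1, so an input of ones comes out as 1/sqrt(1 + epsilon). The two epsilon values are the library defaults (1e-3 for Keras, 1e-5 for PyTorch):

```python
import math

def batch_norm_infer(values, moving_mean, moving_var, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Hypothetical helper: inference-mode batch norm using stored moving statistics."""
    return [gamma * (v - moving_mean) / math.sqrt(moving_var + epsilon) + beta for v in values]

# Freshly initialized statistics: moving mean 0, moving variance 1
channel = [1.0] * 8
keras_like = batch_norm_infer(channel, 0.0, 1.0, epsilon=1e-3)  # Keras default epsilon
torch_like = batch_norm_infer(channel, 0.0, 1.0, epsilon=1e-5)  # PyTorch default eps
print(round(keras_like[0], 7))  # 0.9995004
print(round(torch_like[0], 7))  # 0.999995
```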
    

    self.moving_mean and self.moving_var are non-trainable variables that are updated each time the layer is called in training mode, as follows:

        moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
        moving_var = moving_var * momentum + var(batch) * (1 - momentum)
    

    ref: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
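The two libraries also define momentum in opposite ways: in the Keras update above (default momentum 0.99) momentum weights the old moving value, while PyTorch's BatchNorm2d (default momentum 0.1) weights the new batch statistic. Setting both to the same number therefore does not make them equivalent; the PyTorch value has to be 1 minus the Keras value. A small sketch with hypothetical helpers `keras_update` and `torch_update`:

```python
def keras_update(moving, batch_stat, momentum=0.99):
    # Keras convention: momentum weights the OLD moving value
    return moving * momentum + batch_stat * (1 - momentum)

def torch_update(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum weights the NEW batch statistic
    return (1 - momentum) * running + momentum * batch_stat

# Start from the initial moving mean (0) and feed three batches whose mean is 1
k = t = 0.0
for _ in range(3):
    k = keras_update(k, 1.0, momentum=0.99)
    t = torch_update(t, 1.0, momentum=0.01)  # 0.01 == 1 - 0.99 matches Keras
print(round(k, 6), round(t, 6))  # 0.029701 0.029701
```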

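    If you run the PyTorch batch norm in eval mode, you get close results. The remaining discrepancy comes from the different default epsilon (1e-3 in Keras, 1e-5 in PyTorch): with freshly initialized moving statistics (mean 0, variance 1), an input of ones is mapped to 1/sqrt(1 + epsilon), which is 0.9995004 in Keras and rounds to 1.0000 in PyTorch: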

    from torch import nn
    import torch
    x = torch.ones((1, 2, 2, 2))  # (N, C, H, W): 2 channels
    a = nn.Sequential(
                # nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=True),
                nn.BatchNorm2d(2)
            )
    a.eval()  # inference mode: normalize with running_mean / running_var
    b = a(x)
    print(b)
    import tensorflow as tf
    import tensorflow.keras as keras
    from tensorflow.keras.layers import *
    x = tf.ones((1, 2, 2, 2))  # (N, H, W, C): 2 channels
    a = keras.models.Sequential([
                # Conv2D(128, (1, 1), (1, 1), padding='same', use_bias=True),
                BatchNormalization()
            ])
    b = a(x)  # training defaults to False: normalize with moving_mean / moving_var
    print(b)
    

    out:

    tensor([[[[1.0000, 1.0000],
              [1.0000, 1.0000]],
    
             [[1.0000, 1.0000],
              [1.0000, 1.0000]]]], grad_fn=<NativeBatchNormBackward>)
    tf.Tensor(
    [[[[0.9995004 0.9995004]
       [0.9995004 0.9995004]]
    
      [[0.9995004 0.9995004]
       [0.9995004 0.9995004]]]], shape=(1, 2, 2, 2), dtype=float32)