Tags: python, tensorflow, keras, batch-normalization

tf.keras.layers.BatchNormalization giving unexpected output


import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x: eager mode must be switched on explicitly

x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
print(tf.keras.layers.BatchNormalization()(x))
print(tf.contrib.layers.batch_norm(x))

The output of the above code (in TensorFlow 1.15) is:

tf.Tensor([[ 4.99 69.96] [ 4.99 59.97]], shape=(2, 2), dtype=float32)
tf.Tensor([[ 0. 0.99998] [ 0. -0.99998]], shape=(2, 2), dtype=float32)

My question is why two functions that should compute the same thing give completely different outputs. I also played with some of the functions' parameters, but the result was the same. The second output is the one I want, and PyTorch's BatchNorm gives the same output as the second one, so I suspect the issue is with Keras.

Does anyone know how to fix BatchNormalization in Keras?


Solution

  • The BatchNormalization layer behaves differently during training and inference:

    1. During training (i.e. when using fit() or when calling the layer/model with the argument training=True), the layer normalizes its output using the mean and standard deviation of the current batch of inputs.

    2. During inference (i.e. when using evaluate() or predict(), or when calling the layer/model with the argument training=False, which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training.

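    So the first result comes from the Keras default of training=False: a freshly created layer still has its initial moving mean of 0 and moving variance of 1, so each input is merely divided by sqrt(1 + epsilon), with epsilon = 0.001 by default (5.0 becomes about 4.9975). The second result comes from the default is_training=True of tf.contrib.layers.batch_norm, which normalizes with the current batch's statistics.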

    To get the same result from both, either call the Keras layer in training mode:

    x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
    print(tf.keras.layers.BatchNormalization()(x, training=True).numpy().tolist())
    print(tf.contrib.layers.batch_norm(x).numpy().tolist())
    #output
    #[[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
    #[[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
    

    or run tf.contrib.layers.batch_norm in inference mode:

    x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
    print(tf.keras.layers.BatchNormalization()(x).numpy().tolist())
    print(tf.contrib.layers.batch_norm(x, is_training=False).numpy().tolist())
    #output
    #[[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
    #[[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
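The difference between the two modes boils down to simple arithmetic. Below is a minimal pure-Python sketch (no TensorFlow needed) of the per-feature formula y = gamma * (x - mean) / sqrt(var + epsilon) + beta. It assumes the layers' default settings: epsilon = 0.001, gamma = 1, beta = 0, and freshly initialized moving statistics (mean 0, variance 1). The helper names (column_stats, normalize, update_moving) are hypothetical, and the moving-average update is simplified; the real layers may differ in details such as Bessel's correction.

```python
import math

# Defaults assumed for tf.keras.layers.BatchNormalization / tf.contrib.layers.batch_norm
EPSILON = 0.001
MOMENTUM = 0.99  # Keras default; the contrib layer calls this `decay` (default 0.999)


def column_stats(batch):
    """Per-feature mean and (biased) variance of a 2-D list of rows."""
    n = len(batch)
    columns = list(zip(*batch))
    means = [sum(col) / n for col in columns]
    variances = [sum((v - m) ** 2 for v in col) / n
                 for col, m in zip(columns, means)]
    return means, variances


def normalize(batch, means, variances, gamma=1.0, beta=0.0, eps=EPSILON):
    """y = gamma * (x - mean) / sqrt(var + eps) + beta, applied per feature."""
    return [[gamma * (v - m) / math.sqrt(var + eps) + beta
             for v, m, var in zip(row, means, variances)]
            for row in batch]


def update_moving(moving, batch_stat, momentum=MOMENTUM):
    """Simplified exponential moving average used to track statistics in training."""
    return [momentum * m + (1 - momentum) * b for m, b in zip(moving, batch_stat)]


x = [[5.0, 70.0], [5.0, 60.0]]
moving_mean, moving_var = [0.0, 0.0], [1.0, 1.0]  # initial values of a new layer

# training=True / is_training=True: normalize with the current batch's statistics.
batch_mean, batch_var = column_stats(x)
train_out = normalize(x, batch_mean, batch_var)
print(train_out)  # roughly [[0.0, 0.99998], [0.0, -0.99998]]

# training=False / is_training=False: normalize with the stored moving statistics.
infer_out = normalize(x, moving_mean, moving_var)
print(infer_out)  # roughly [[4.9975, 69.965], [4.9975, 59.97]]

# One training step nudges the moving statistics only slightly toward the batch.
moving_mean = update_moving(moving_mean, batch_mean)
moving_var = update_moving(moving_var, batch_var)
print(moving_mean, moving_var)  # roughly [0.05, 0.65] [0.99, 1.24]
```

Because the moving statistics start at mean 0 and variance 1 and move slowly (momentum 0.99), an untrained layer called in inference mode returns almost its input unchanged, which is the first output in the question.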