Tags: python, tensorflow, tensorflow2.0, batch-normalization

tf.keras.layers.BatchNormalization with trainable=False appears to not update its internal moving mean and variance


I am trying to find out how exactly the BatchNormalization layer behaves in TensorFlow. I came up with the following piece of code, which to the best of my knowledge should be a perfectly valid Keras model; however, the moving mean and variance of the BatchNormalization layer don't appear to be updated.

From the docs (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization):

in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

I expect the model to return a different value with each subsequent predict call. What I see, however, are the exact same values returned 10 times. Can anyone explain why the BatchNormalization layer does not update its internal values?

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    # trainable=False should, per the docs quote above, run the layer in
    # inference mode, i.e. normalize with the moving mean and variance
    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = inputs = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=inputs, outputs=z)

    for i in range(10):
        print(x)
        print(model.predict(x))  # prints the exact same values every time
        print()

I am using TensorFlow 2.1.0.


Solution

  • Okay, I found the mistake in my assumptions. The moving averages are updated during training, not during inference as I had thought. This makes perfect sense, as updating the moving averages during inference would likely result in an unstable production model: for example, a long sequence of highly pathological input samples (e.g. ones whose generating distribution differs drastically from the one on which the network was trained) could bias the network and degrade its performance on valid input samples.
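    You can verify this directly by inspecting the layer's moving_mean and moving_variance variables around a forward pass. The sketch below is my own illustration (not from the original question): a plain inference call leaves the statistics untouched, while a training=True call updates them with the rule moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean.

    import tensorflow as tf
    import numpy as np

    np.random.seed(1)
    x = (np.random.randn(10, 5) * 5 + 0.3).astype(np.float32)

    bn = tf.keras.layers.BatchNormalization(momentum=0.9)
    inputs = tf.keras.layers.Input([5])
    model = tf.keras.Model(inputs=inputs, outputs=bn(inputs))

    print(bn.moving_mean.numpy())  # initialized to zeros
    model(x)                       # inference call: no update
    print(bn.moving_mean.numpy())  # still zeros

    model(x, training=True)        # training call: statistics get updated
    print(bn.moving_mean.numpy())
    # should match momentum * 0 + (1 - momentum) * batch mean, up to float precision
    print(0.1 * x.mean(axis=0))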

    The trainable parameter is useful when you're fine-tuning a pretrained model and want to freeze some of the layers of the network even during training; at inference time it makes no difference, because when you call model.predict(x) (or model(x), or model(x, training=False)), the layer automatically uses the moving averages instead of the batch statistics anyway.
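    For example, here is a minimal (hypothetical) fine-tuning sketch, using MobileNetV2 purely as an example backbone: freezing the base keeps its batchnorm statistics fixed while the new head trains.

    import tensorflow as tf

    # freeze a pretrained backbone, including its batchnorm layers
    base = tf.keras.applications.MobileNetV2(include_top=False, pooling='avg')
    base.trainable = False

    inputs = tf.keras.layers.Input([224, 224, 3])
    z = base(inputs, training=False)  # also run batchnorm in inference mode
    outputs = tf.keras.layers.Dense(10, activation='softmax')(z)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)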

    The code below demonstrates this clearly:

    import tensorflow as tf
    import numpy as np
    
    if __name__ == '__main__':
    
        np.random.seed(1)
        x = np.random.randn(10, 5) * 5 + 0.3
    
        z = inputs = tf.keras.layers.Input([5])
        z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(z)
    
        model = tf.keras.Model(inputs=inputs, outputs=z)
        
        # a dummy elementwise loss function
        model.compile(loss=lambda y_true, y_pred: (y_true - y_pred) ** 2)
    
        # a dummy fit just to update the batchnorm moving averages
        model.fit(x, x, batch_size=3, epochs=10)
        
        # the first inference call uses the moving averages accumulated during training
        pred = model(x).numpy()
        print(pred.mean(axis=0))
        print(pred.var(axis=0))
        print()
        
        # outputs the same thing as the previous call
        pred = model(x).numpy()
        print(pred.mean(axis=0))
        print(pred.var(axis=0))
        print()
        
        # calling the model with training=True updates the moving averages;
        # furthermore, it uses the batch mean and variance as in training, 
        # so the result is very different
        pred = model(x, training=True).numpy()
        print(pred.mean(axis=0))
        print(pred.var(axis=0))
        print()
        
        # here we see again that the moving averages are used, but they differ
        # slightly from before because the previous training=True call updated them
        pred = model(x).numpy()
        print(pred.mean(axis=0))
        print(pred.var(axis=0))
        print()
    

    In the end, I found that the documentation (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization) mentions this:

    1. When performing inference using a model containing batch normalization, it is generally (though not always) desirable to use accumulated statistics rather than mini-batch statistics. This is accomplished by passing training=False when calling the model, or using model.predict.
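    As a quick sanity check (my own snippet, reusing model and x from the code above), all three inference-style calls agree:

    a = model(x).numpy()
    b = model(x, training=False).numpy()
    c = model.predict(x)
    print(np.allclose(a, b), np.allclose(a, c, atol=1e-6))  # True True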

    Hopefully this will help someone with a similar misunderstanding in the future.