tensorflow, machine-learning, neural-network, batch-normalization

How does BatchNormalization work on an example?


I am trying to understand batch norm. Here is my humble example:

import numpy as np
import tensorflow as tf

layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3., 4.]])
out = layer1(x)
print(out)

Prints

tf.Tensor([[2.99850112 3.9980015 ]], shape=(1, 2), dtype=float64)

My attempt to reproduce it:

e = 0.001                      # epsilon, the default for BatchNormalization
m = np.sum(x) / 2              # mean of the two values
b = np.sum((x - m)**2) / 2     # (biased) variance of the two values
x_ = (x - m) / np.sqrt(b + e)  # normalize
print(x_)

It prints

[[-0.99800598  0.99800598]]

What am I doing wrong?


Solution

Two problems here.

First, batch norm has two "modes": training, where normalization is done via the batch statistics, and inference, where normalization is done via "population statistics" that are collected from batches during training. By default, Keras layers/models run in inference mode, and you need to pass training=True in the call to change this (there are other ways, but that is the simplest one).

layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3., 4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
    

This prints tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32). Still not right!
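Incidentally, this explains the output of your original snippet: in inference mode, the layer normalizes with its moving statistics, which Keras initializes to moving_mean = 0 and moving_variance = 1, together with the default epsilon of 0.001. A minimal reproduction of the inference-mode computation using those defaults:

x = np.array([[3., 4.]])
moving_mean = 0.0  # Keras initializes moving_mean to zeros
moving_var = 1.0   # ...and moving_variance to ones
eps = 0.001        # default epsilon

# inference-mode normalization: (x - moving_mean) / sqrt(moving_var + eps)
print((x - moving_mean) / np.sqrt(moving_var + eps))

This prints [[2.99850112 3.9980015 ]], matching the inference-mode output you saw.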

Second, batch norm normalizes over the batch axis, separately for each feature. However, the way you specify the input (as a 1x2 array) is a single input (batch size 1) with two features. Batch norm then just normalizes each feature to mean 0 (the standard deviation is not defined for a single sample). Instead, you want two inputs with a single feature:

layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3.], [4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
    

This prints the following, which matches your manual computation up to float32 precision:

tf.Tensor(
[[-0.99800634]
 [ 0.99800587]], shape=(2, 1), dtype=float32)
    

Alternatively, specify the "feature axis":

layer1 = tf.keras.layers.BatchNormalization(axis=0, scale=False, center=False)
x = np.array([[3., 4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
    

Note that the input shape is "wrong" here, but we told batch norm that axis 0 is the feature axis (it defaults to -1, the last axis). This also gives the desired result:

tf.Tensor([[-0.99800634  0.99800587]], shape=(1, 2), dtype=float32)
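
Finally, to see the "population statistics" from the first point actually being collected, you can inspect the layer's moving averages after a training-mode call. A small sketch; the moving_mean/moving_variance attributes and the default momentum of 0.99 are standard Keras behavior, though the exact variance update can differ slightly between TF code paths:

layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3.], [4.]], dtype=np.float32)
layer1(x, training=True)  # a training-mode call updates the moving statistics

# moving_mean moves towards the batch mean (3.5) with momentum 0.99:
# 0.99 * 0 + 0.01 * 3.5 = 0.035
print(layer1.moving_mean.numpy())      # [0.035]
print(layer1.moving_variance.numpy())  # slightly below 1, moving towards the batch variance

These are the statistics that inference-mode calls use, which is why the freshly initialized layer (moving mean 0, moving variance 1) barely changed your input in the very first example.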