I am trying to understand batchnorm. Here is my humble example:
import numpy as np
import tensorflow as tf

layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3.,4.]])
out = layer1(x)
print(out)
Prints
tf.Tensor([[2.99850112 3.9980015 ]], shape=(1, 2), dtype=float64)
My attempt to reproduce it
e=0.001
m = np.sum(x)/2
b = np.sum((x - m)**2)/2
x_=(x-m)/np.sqrt(b+e)
print(x_)
It prints
[[-0.99800598 0.99800598]]
What am I doing wrong?
Two problems here.
First, batch norm has two "modes": training, where normalization uses the statistics of the current batch, and inference, where normalization uses "population statistics" that are collected from batches during training. By default, Keras layers/models run in inference mode, and you need to pass training=True in their call to change this (there are other ways, but that is the simplest one).
layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3.,4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
This prints tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32). Still not right!
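As a side check on the first problem, the numbers in the question are consistent with inference mode. The sketch below uses plain NumPy and assumes a freshly created layer whose stored statistics are still at what I understand to be the Keras defaults (moving mean 0, moving variance 1, epsilon 0.001); the variable names are mine, not part of the Keras API.

```python
import numpy as np

eps = 1e-3                       # assumed: default epsilon of BatchNormalization
x = np.array([[3., 4.]])

# Assumed initial state of an untrained layer: moving mean 0, moving variance 1.
moving_mean = np.zeros(2)
moving_var = np.ones(2)

# Inference mode: normalize with the stored statistics, not the batch statistics.
out_inference = (x - moving_mean) / np.sqrt(moving_var + eps)
print(out_inference)             # approximately [[2.9985 3.9980]]
```

So in inference mode the input is only divided by sqrt(1 + 0.001), which is why the output looks almost identical to the input.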
Second, batch norm normalizes over the batch axis, separately for each feature. However, the way you specify the input (as a 1x2 array), it is a single sample (batch size 1) with two features. With only one sample, each feature's batch mean is the value itself and its variance is 0, so every feature is normalized to 0. Instead, you want two inputs with a single feature:
layer1 = tf.keras.layers.BatchNormalization(scale=False, center=False)
x = np.array([[3.],[4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
This prints
tf.Tensor(
[[-0.99800634]
[ 0.99800587]], shape=(2, 1), dtype=float32)
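The same behaviour can be reproduced by hand. This is a minimal NumPy sketch of training-mode normalization with scale and center disabled, assuming epsilon 0.001 and the biased (divide by N) variance; batchnorm_train is a hypothetical helper, not a Keras function.

```python
import numpy as np

def batchnorm_train(x, eps=1e-3):
    """Hypothetical helper: normalize each feature (column) over the batch axis (rows)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)    # biased variance, divides by the batch size
    return (x - mean) / np.sqrt(var + eps)

one_sample = batchnorm_train(np.array([[3., 4.]]))      # batch size 1, two features
two_samples = batchnorm_train(np.array([[3.], [4.]]))   # batch size 2, one feature

print(one_sample)    # all zeros: each feature equals its own batch mean
print(two_samples)   # roughly -0.998 and +0.998
```

The tiny differences between the two tf.Tensor values above and the symmetric NumPy result are presumably float32 rounding.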
Alternatively, specify the "feature axis":
layer1 = tf.keras.layers.BatchNormalization(axis=0, scale=False, center=False)
x = np.array([[3.,4.]], dtype=np.float32)
out = layer1(x, training=True)
print(out)
Note that the input shape is still "wrong", but we told batchnorm that axis 0 is the feature axis (it defaults to -1, the last axis), so the statistics are now computed over axis 1. This will also give the desired result:
tf.Tensor([[-0.99800634 0.99800587]], shape=(1, 2), dtype=float32)
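To make the role of the axis argument concrete, here is a rough NumPy sketch in which the statistics are reduced over every axis except the feature axis. It assumes epsilon 0.001 and the biased variance; batchnorm_train_axis is a hypothetical helper that only mimics the idea, not the actual Keras implementation.

```python
import numpy as np

def batchnorm_train_axis(x, axis=-1, eps=1e-3):
    """Hypothetical helper: reduce over all axes except the feature axis `axis`."""
    feature_axis = axis % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != feature_axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)    # biased variance
    return (x - mean) / np.sqrt(var + eps)

x = np.array([[3., 4.]])
default_axis = batchnorm_train_axis(x)          # features on the last axis, batch of 1
axis0 = batchnorm_train_axis(x, axis=0)         # features on axis 0, reduce over axis 1

print(default_axis)   # all zeros
print(axis0)          # roughly [[-0.998  0.998]]
```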