How to implement Batch Normalization on tensorflow with Keras as a high-level API


BatchNormalization (BN) behaves slightly differently in training and in inference. In training, it uses the mean and variance of the current mini-batch to normalize its inputs; this means that the exact result of applying batch normalization depends not only on the current input, but also on all other elements of the mini-batch. This is clearly undesirable in inference mode, where we want a deterministic result. In that case, fixed statistics are used instead: estimates of the global mean and variance over the entire training set, accumulated during training.
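
To make the two modes concrete, here is a minimal, dependency-free sketch of batch normalization for a single scalar feature. The class `BatchNorm1D` is hypothetical and written only for illustration (it is not the Keras implementation and omits the learnable scale and offset); `momentum=0.99` and `eps=1e-3` are chosen to match the Keras defaults.

```python
import math
import random


class BatchNorm1D:
    """Toy batch normalization for one scalar feature (no learnable gamma/beta)."""

    def __init__(self, momentum=0.99, eps=1e-3):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch, training):
        if training:
            # Training mode: normalize with the statistics of this mini-batch...
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            # ...and update the running estimates used later at inference time.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: normalize with the fixed running estimates.
            mean, var = self.moving_mean, self.moving_var
        scale = math.sqrt(var + self.eps)
        return [(v - mean) / scale for v in batch]


random.seed(18)
bn = BatchNorm1D()

# "Train" on mini-batches drawn from N(0, 1) so the running estimates settle.
for _ in range(1000):
    bn([random.gauss(0.0, 1.0) for _ in range(32)], training=True)

# The same sample placed in two different mini-batches.
sample = 1.5
batch_a = [sample] + [random.gauss(0.0, 1.0) for _ in range(31)]
batch_b = [sample] + [random.gauss(2.0, 1.0) for _ in range(31)]

# Inference mode: the result for `sample` does not depend on its neighbours.
infer_a = bn(batch_a, training=False)[0]
infer_b = bn(batch_b, training=False)[0]

# Training mode: the result for `sample` changes with the rest of the batch.
train_a = bn(batch_a, training=True)[0]
train_b = bn(batch_b, training=True)[0]

print("inference:", infer_a, infer_b)
print("training: ", train_a, train_b)
```

In inference mode the two outputs for `sample` are identical, whereas in training mode they differ because the two mini-batches have different means.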

In TensorFlow, this behavior is controlled by a boolean argument, training, that is passed when calling the layer; see https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization. How do I deal with this switch when using the Keras high-level API? Am I correct in assuming that it is handled automatically, depending on whether we are calling model.fit(x, ...) or model.predict(x, ...)?


To test this, I have written the following example. The training inputs are drawn from a standard normal distribution, and the task is to classify whether the input is positive or negative. However, the test dataset comes from a different distribution, in which the inputs are shifted by 2 (and consequently the labels check whether x > 2).

import numpy as np
from math import ceil
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, BatchNormalization

# Use the public API rather than the private tensorflow.python namespace
Dataset = tf.data.Dataset

np.random.seed(18)
xt = np.random.randn(10_000, 1)
yt = np.array([[int(x > 0)] for x in xt])
train_data = Dataset.from_tensor_slices((xt, yt)).shuffle(10_000).repeat().batch(32).prefetch(2)

xv = np.random.randn(100, 1)
yv = np.array([[int(x > 0)] for x in xv])
valid_data = Dataset.from_tensor_slices((xv, yv)).repeat().batch(32).prefetch(2)

xs = np.random.randn(100, 1) + 2
ys = np.array([[int(x > 2)] for x in xs])
test_data = Dataset.from_tensor_slices((xs, ys)).repeat().batch(32).prefetch(2)

x = Input(shape=(1,))
a = BatchNormalization()(x)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a)
y = Dense(1, activation='sigmoid')(a)
model = Model(inputs=x, outputs=y)
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_data, epochs=10, steps_per_epoch=ceil(10_000 / 32), validation_data=valid_data,
          validation_steps=ceil(100 / 32))
zs = model.predict(test_data, steps=ceil(100 / 32))
# Fraction of the 100 test examples that are classified correctly
print(np.mean([ys[i, 0] == int(zs[i, 0] > 0.5) for i in range(100)]))

Running the code prints the value 0.5, meaning that only half of the test examples are labeled correctly. This is what I would expect if the system were using the global statistics of the training set to implement BN.

If we change the BN layers to read

x = Input(shape=(1,))
a = BatchNormalization()(x, training=True)
a = Dense(8, activation='sigmoid')(a)
a = BatchNormalization()(a, training=True)
y = Dense(1, activation='sigmoid')(a)

and run the code again, we find 0.87. With the layers always forced into the training state, the fraction of correct predictions has changed. This is consistent with the idea that model.predict(x, ...) is now using the statistics of the mini-batch to implement BN, and is therefore able to partially "correct" the mismatch between the training and test distributions.
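
The same effect can be reproduced without a network. The sketch below is a simplified stand-in, not the model above: it assumes the trained network behaves like a threshold at zero on the normalized input (the hypothetical `classify` helper), and that the fixed training statistics are exactly mean 0 and variance 1.

```python
import math
import random

random.seed(18)


def normalize(batch, mean, var, eps=1e-3):
    """Normalize a list of scalars with the given mean and variance."""
    scale = math.sqrt(var + eps)
    return [(v - mean) / scale for v in batch]


def batch_stats(batch):
    """Mean and (biased) variance of one mini-batch."""
    mean = sum(batch) / len(batch)
    var = sum((v - mean) ** 2 for v in batch) / len(batch)
    return mean, var


def classify(normalized):
    """Hypothetical stand-in for the trained network: 1 if the input is positive."""
    return [int(v > 0.0) for v in normalized]


def accuracy(preds, labels):
    return sum(p == t for p, t in zip(preds, labels)) / len(labels)


# Shifted test set: inputs ~ N(2, 1), label is 1 when x > 2.
xs = [random.gauss(2.0, 1.0) for _ in range(3200)]
ys = [int(v > 2.0) for v in xs]

preds_global, preds_batch = [], []
for start in range(0, len(xs), 32):
    batch = xs[start:start + 32]
    # Inference-style BN: fixed statistics of the training distribution.
    preds_global += classify(normalize(batch, 0.0, 1.0))
    # training=True-style BN: statistics of the current mini-batch.
    preds_batch += classify(normalize(batch, *batch_stats(batch)))

acc_global = accuracy(preds_global, ys)
acc_batch = accuracy(preds_batch, ys)
print(f"fixed statistics: {acc_global:.2f}, batch statistics: {acc_batch:.2f}")
```

With fixed statistics nearly every shifted input looks positive, so accuracy stays near one half; with mini-batch statistics each batch is re-centered around its own mean, which is close to 2, so most labels are recovered.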

Is that correct?


Solution

  • If I'm understanding your question correctly, then yes: Keras automatically manages training vs. inference behavior depending on whether you call fit or predict/evaluate. The flag is called learning_phase, and it determines the behavior of batch norm, dropout, and potentially other layers. The current learning phase can be read with keras.backend.learning_phase() and set with keras.backend.set_learning_phase().

    https://keras.io/backend/#learning_phase
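
    A quick way to check this yourself is to call a model directly with an explicit training argument and compare against predict. This is a minimal sketch assuming TensorFlow 2.x with tf.keras; the one-layer model is purely illustrative and untrained. As far as I know, later TensorFlow 2.x releases deprecate set_learning_phase in favor of this per-call training argument.

```python
import numpy as np
import tensorflow as tf

# Illustrative one-layer model: a single, untrained BatchNormalization layer.
inputs = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.BatchNormalization()(inputs)
model = tf.keras.Model(inputs, outputs)

# One mini-batch whose mean is far from the layer's initial moving mean (0).
rng = np.random.default_rng(0)
x = (rng.standard_normal((32, 1)) + 2.0).astype("float32")

# Explicit inference mode: uses the stored moving mean and variance.
inference_out = model(x, training=False).numpy()

# predict() is expected to run the layer in inference mode as well.
predict_out = model.predict(x, verbose=0)

# Explicit training mode: uses the statistics of this mini-batch
# (and, as a side effect, updates the moving averages).
training_out = model(x, training=True).numpy()

print("inference mean:", inference_out.mean())
print("predict mean:  ", predict_out.mean())
print("training mean: ", training_out.mean())
```

    The inference and predict outputs should agree and keep a mean near 2, while the training-mode output is re-centered to a mean near 0. Note that passing training=True when wiring the layers, as in your second snippet, hardcodes training behavior and overrides what fit or predict would otherwise choose.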