
Why does batch_normalization produce all-zero output when training=True but non-zero output when training=False?


I am following the TensorFlow migration guide (https://www.tensorflow.org/guide/migrate). Here is an example:

import tensorflow as tf
import tensorflow.compat.v1 as v1

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
          kernel_regularizer=lambda x:0.004*tf.reduce_mean(x**2))
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    x = v1.layers.dropout(x, 0.1, training=training)
    x = v1.layers.dense(x, 64, activation=v1.nn.relu)
    # training=True: normalize with the current batch's mean/variance
    # training=False: normalize with the stored moving mean/variance
    x = v1.layers.batch_normalization(x, training=training)
    x = v1.layers.dense(x, 10)
    return x

# A batch of one 28x28 single-channel image filled with ones
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
test_out = model(test_data, training=False)
print(train_out)
print(test_out)

With training=True, train_out is:

tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)

while with training=False, test_out is a seemingly random non-zero vector:

tf.Tensor(
[[ 0.379358   -0.55901194  0.48704922  0.11619566  0.23902717  0.01691487
   0.07227738  0.14556988  0.2459927   0.2501198 ]], shape=(1, 10), dtype=float32)

I read the documentation at https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization, but I still don't understand why this happens. Help!


Solution

  • Why does batch_normalization produce all-zero output when training=True

    It's because your batch size is 1.

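    In training mode, the batch normalization layer normalizes its input using the batch mean and batch variance of each channel: output = gamma * (x - mean) / sqrt(variance + epsilon) + beta.

    With a batch size of 1, after flatten each channel contains just one value, so the batch mean of that channel is the value itself and the batch variance is 0. Every normalized value is therefore (x - x) / sqrt(0 + epsilon) = 0, and since beta is initialized to 0, the layer outputs a zero tensor. The final dense layer's bias is also initialized to 0, so the model output is all zeros too.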

    but non-zero output when training=False?

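    With training=False (inference), the layer instead normalizes its input with moving averages of the batch mean and variance accumulated during training, rather than with the current batch's statistics.

    The moving mean and moving variance are initialized to 0 and 1, respectively, and are updated only gradually during training: moving = momentum * moving + (1 - momentum) * batch_statistic, with momentum = 0.99 by default. At this point the moving mean is still at or near 0, so it does not equal the single value in each channel, and the layer computes roughly x / sqrt(1 + epsilon) ≈ x instead of a zero tensor.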

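    In conclusion: use a batch size greater than 1 and inputs with varied (random or realistic) values rather than tf.ones(), in which all elements are the same. Both are needed: with tf.ones(), every sample in the batch is identical, so each channel still has zero batch variance and still normalizes to 0, even with a larger batch.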
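The arithmetic in this answer can be checked without TensorFlow. Below is a minimal pure-Python sketch of the batch normalization formula, written only for illustration: the helpers batch_norm_train, batch_norm_infer and update_moving are hypothetical names, not TensorFlow API. It assumes gamma = 1 and beta = 0 (their initial values), epsilon = 0.001 and momentum = 0.99 (the defaults of tf.compat.v1.layers.batch_normalization), and the biased batch variance; TensorFlow's own implementation may differ in such details.

```python
import math

EPSILON = 1e-3   # assumed: default epsilon of tf.compat.v1.layers.batch_normalization
MOMENTUM = 0.99  # assumed: default momentum of tf.compat.v1.layers.batch_normalization


def batch_norm_train(batch, gamma=1.0, beta=0.0, eps=EPSILON):
    """Training mode: normalize each feature with the current batch's mean and variance.

    batch is a list of samples, each a list of feature values (like the flattened
    (batch_size, features) tensor in the question). Returns (outputs, means, variances).
    """
    n = len(batch)
    num_features = len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(num_features)]
    # Biased (population) variance of each feature over the batch
    variances = [
        sum((row[j] - means[j]) ** 2 for row in batch) / n for j in range(num_features)
    ]
    outputs = [
        [gamma * (row[j] - means[j]) / math.sqrt(variances[j] + eps) + beta
         for j in range(num_features)]
        for row in batch
    ]
    return outputs, means, variances


def batch_norm_infer(batch, moving_mean, moving_var, gamma=1.0, beta=0.0, eps=EPSILON):
    """Inference mode: normalize with the stored moving statistics instead."""
    return [
        [gamma * (x - m) / math.sqrt(v + eps) + beta
         for x, m, v in zip(row, moving_mean, moving_var)]
        for row in batch
    ]


def update_moving(moving, batch_stat, momentum=MOMENTUM):
    """Exponential moving average: moving = momentum * moving + (1 - momentum) * batch_stat."""
    return [momentum * m + (1.0 - momentum) * s for m, s in zip(moving, batch_stat)]


# 1) Batch size 1: each feature's mean equals its only value and its variance is 0.
single = [[0.5, 2.0, 3.0]]
print(batch_norm_train(single)[0])  # [[0.0, 0.0, 0.0]]

# 2) Batch size 2 but identical rows (like tf.ones): still zero variance, still zeros.
identical = [[1.0, 1.0], [1.0, 1.0]]
print(batch_norm_train(identical)[0])  # [[0.0, 0.0], [0.0, 0.0]]

# 3) Varied rows: non-zero batch variance, so the output is non-zero.
varied = [[1.0, 2.0], [3.0, 6.0]]
print(batch_norm_train(varied)[0])

# 4) Inference with freshly initialized moving stats (mean 0, variance 1):
#    the output is x / sqrt(1 + eps), i.e. close to x, not zero.
print(batch_norm_infer(single, moving_mean=[0.0] * 3, moving_var=[1.0] * 3))

# 5) The moving mean approaches a constant batch mean of 0.5 only gradually.
moving_mean = [0.0]
for step in range(1, 501):
    moving_mean = update_moving(moving_mean, [0.5])
    if step in (1, 10, 100, 500):
        print(step, moving_mean[0])
```

Cases 1 and 2 print all zeros, matching the training=True behaviour in the question, while cases 3 and 4 do not. Case 5 shows why the moving statistics start far from the batch statistics: with momentum 0.99 and a constant batch mean, the moving mean has covered only a fraction 1 - 0.99^k of the distance after k updates.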