
On the use of Batch Normalization


I am trying to make sure that I am incorporating batch normalization layers into a model correctly.

The code snippet below illustrates what I am doing.

  1. Is this an appropriate use of batch normalization?
  2. At inference time, how can I access the moving averages in each batch normalization layer to make sure they are being loaded?


import os

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # needed to run the TF1-style graph/session code below under TF 2.x
from model import Model

# Sample batch normalization layer in the Model class
x_preBN = ...
# With training=True the layer normalizes with the current batch statistics
# and updates the moving averages; with training=False it uses the stored
# moving averages instead.
x_postBN = tf.layers.batch_normalization(inputs=x_preBN,
                                         center=True,
                                         scale=True,
                                         momentum=0.9,
                                         training=(self.mode == 'train'))

# During training:
model = Model(mode='train')
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for it in range(max_iterations):
    # Training step + update of the BN moving statistics
    sess.run([train_step, extra_update_ops], feed_dict=...)

    # Store checkpoint
    if it % num_checkpoint_steps == 0:
        saver.save(sess,
                   os.path.join(model_dir, 'checkpoint'),
                   global_step=it)
        

# During inference:
model = Model(mode='eval')
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, os.path.join(model_dir, 'checkpoint-???'))
  acc = sess.run(model.accuracy, feed_dict=...)

Solution

  • Once the model has been instantiated, a list of all global variables can be obtained as follows:

    model = Model(mode='eval')
    saver = tf.train.Saver()
    print(tf.global_variables())
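
    If the full list is long, one way to narrow it down is to filter the variables by name; 'batch_normalization' is the default name scope used by tf.layers.batch_normalization. A minimal sketch:

    bn_vars = [v for v in tf.global_variables()
               if 'batch_normalization' in v.name]
    for v in bn_vars:
        print(v.name, v.shape)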
    

    The batch normalization variables for a specific layer look like this: gamma and beta are trainable, whereas the moving statistics are not, which is why they have to be updated explicitly via extra_update_ops during training (see the control-dependency sketch after the listing).

    <tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/gamma:0' shape=(16,) dtype=float32>,
    <tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/beta:0' shape=(16,) dtype=float32>,
    <tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/moving_mean:0' shape=(16,) dtype=float32>,
    <tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/moving_variance:0' shape=(16,) dtype=float32>
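
    Rather than fetching extra_update_ops explicitly in every sess.run call, a common alternative is to attach the update ops as control dependencies of the training step, so the moving statistics are refreshed automatically. A minimal sketch; loss and the choice of optimizer are placeholders, not defined in the snippets above:

    # Fold the BN moving-statistics updates into the train op itself.
    optimizer = tf.train.AdamOptimizer(1e-3)  # assumed optimizer
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # minimize() now runs only after the moving mean/variance updates
        train_step = optimizer.minimize(loss)  # `loss` assumed to exist

    # A single fetch then suffices during training:
    # sess.run(train_step, feed_dict=...)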
    

    They can be accessed as usual:

    # Look the variable up by name in the restored graph
    ma = tf.get_default_graph().get_tensor_by_name(
        'unit_1_1/residual_only_activation/batch_normalization/moving_mean:0')
    with tf.Session() as sess:
      saver.restore(sess, os.path.join(model_dir, 'checkpoint-???'))
      print(sess.run(ma))
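
    To verify that the moving averages are actually stored and loaded (question 2), the checkpoint file can also be inspected directly, without building the graph. A sketch; the checkpoint path must be filled in with a concrete global step:

    ckpt_path = os.path.join(model_dir, 'checkpoint-???')  # fill in the step

    # List every (name, shape) pair stored in the checkpoint
    for name, shape in tf.train.list_variables(ckpt_path):
        if 'moving_mean' in name or 'moving_variance' in name:
            print(name, shape)

    # Read one saved tensor directly and compare with sess.run(ma) after restore
    reader = tf.train.load_checkpoint(ckpt_path)
    print(reader.get_tensor(
        'unit_1_1/residual_only_activation/batch_normalization/moving_mean'))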