Tags: tensorflow, machine-learning, batch-normalization, tensorflow-estimator

Updating batch_normalization mean & variance using Estimator API


The documentation isn't 100% clear on this:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:
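The example that follows that note in the documentation is along these lines (reconstructed from the TF 1.x docs, not quoted verbatim):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)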

(see https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)

Does that mean that all that is needed to save the moving_mean and moving_variance is the following?

def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    x = tf.reshape(features, [-1, 64, 64, 3])
    x = tf.layers.batch_normalization(x, training=training)

    # ...

    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

In other words, does simply using

with tf.control_dependencies(extra_update_ops):

take care of saving the moving_mean and moving_variance?


Solution

  • As it turns out, those values are saved automatically: moving_mean and moving_variance are created as non-trainable variables, so they are written to checkpoints along with the rest of the model. The edge case is ordering: if you call tf.get_collection(tf.GraphKeys.UPDATE_OPS) before the batch normalization op has been added to the graph, the collection is empty, the control dependency is a no-op, and the moving statistics are never updated during training. This had not been documented before, but is now.

    The caveat when using batch_norm is therefore to call tf.get_collection(tf.GraphKeys.UPDATE_OPS) only after you've called tf.layers.batch_normalization.
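To make the ordering concrete, here is a minimal sketch of the model_fn above with the collection read after the layer is built; the loss and optimizer lines are illustrative placeholders, not part of the original question:

def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN

    x = tf.reshape(features, [-1, 64, 64, 3])
    x = tf.layers.batch_normalization(x, training=training)

    # ... rest of the network producing `loss`, plus an `optimizer` ...

    # Read UPDATE_OPS only now, after batch_normalization has added its
    # moving_mean/moving_variance update ops to the collection.
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(
            loss=loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

With this ordering, extra_update_ops actually contains the moving-statistics update ops, so running train_op both applies gradients and refreshes moving_mean and moving_variance.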