
Batch Normalization in a Custom Estimator in Tensorflow


I'm referring to a note in the documentation for tf.layers.batch_normalization:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

How would one implement this in a custom Estimator? Take, for example, this example on TensorFlow's website: The complete abalone model_fn


Solution

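  • I guess you can build the train_op you refer to inside the tf.control_dependencies block within your model_fn, then pass it as the train_op argument of the EstimatorSpec that the model_fn returns.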
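
To make that concrete, here is a minimal sketch of what such a model_fn might look like. It assumes TensorFlow 1.x (the same API the quoted note uses). The layer sizes, the "x" feature key, the "ages" prediction key, and the "learning_rate" entry in params are illustrative assumptions, not taken from the abalone tutorial itself.

```python
import tensorflow as tf  # assumes TensorFlow 1.x (tf.layers, tf.train, tf.estimator)


def model_fn(features, labels, mode, params):
    """Hypothetical regression model_fn with one batch-normalized hidden layer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # "x" is an assumed feature key; the layer sizes are arbitrary.
    net = tf.layers.dense(features["x"], 10)
    # With training=True the layer normalizes with batch statistics and
    # registers its moving-average updates in tf.GraphKeys.UPDATE_OPS.
    net = tf.layers.batch_normalization(net, training=is_training)
    net = tf.nn.relu(net)
    predictions = tf.squeeze(tf.layers.dense(net, 1), axis=1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"ages": predictions})

    loss = tf.losses.mean_squared_error(labels, predictions)

    train_op = None
    if is_training:
        # "learning_rate" is an assumed key in params.
        optimizer = tf.train.GradientDescentOptimizer(params["learning_rate"])
        # Make the training step depend on the moving_mean / moving_variance updates.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, global_step=tf.train.get_global_step())

    # train_op stays None in EVAL mode, where only the loss is required.
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```

You would then hand it to the Estimator as usual, e.g. tf.estimator.Estimator(model_fn=model_fn, params={"learning_rate": 0.001}). The key point is that the training flag of the batch normalization layer is derived from mode, so the moving averages are only updated during training and are used as-is during evaluation and prediction.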