Batch normalization in TensorFlow: variables and performance

I would like to add conditional operations on the variables of a batch normalization layer. Specifically, I want to train in floating point, then quantize during a secondary fine-tuning phase. For this, I want to add a tf.cond operation on the variables (scale, shift, and the exponential moving averages of the mean and variance).
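The intended switch can be sketched outside TensorFlow. The NumPy snippet below is only an illustration: the quantization scheme (symmetric, uniform, per-tensor, 8 bits by default) and the helper names fake_quantize and bn_params are assumptions, not part of the original setup, and the Python if stands in for the tf.cond on the variables.

```python
import numpy as np


def fake_quantize(v, num_bits=8):
    """Symmetric uniform per-tensor quantization (an assumed scheme), returned as floats."""
    qmax = 2 ** (num_bits - 1) - 1        # largest integer level, e.g. 127 for 8 bits
    max_abs = np.max(np.abs(v))
    if max_abs == 0:
        return v.copy()                   # all zeros: nothing to scale
    step = max_abs / qmax                 # width of one quantization step
    return np.round(v / step) * step      # snap each value to the nearest level


def bn_params(gamma, beta, quantize, num_bits=8):
    """Hypothetical helper: return float or quantized scale/shift.

    In the graph this choice would be a tf.cond on the variables;
    here a Python flag plays that role.
    """
    if quantize:
        return fake_quantize(gamma, num_bits), fake_quantize(beta, num_bits)
    return gamma, beta


gamma = np.array([1.0, 0.5, -0.25])
beta = np.array([0.0, 0.1, -0.2])
gamma_f, beta_f = bn_params(gamma, beta, quantize=False)  # float training phase
gamma_q, beta_q = bn_params(gamma, beta, quantize=True)   # quantized fine-tuning phase
```

The same selection would apply to the moving mean and variance before they reach the normalization.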

I replaced tf.layers.batch_normalization with a batchnorm layer I wrote (see below).

This function works correctly (i.e. I get the same metrics with both functions), and I can insert any processing on the variables before the batchnorm operation. The problem is that performance dropped dramatically: simply replacing tf.layers.batch_normalization with my own function, as written below, roughly doubles the runtime.

def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
    epsilon = tf.to_float(epsilon)
    decay = tf.to_float(decay)
    with tf.variable_scope(name):
        shape = x.get_shape().as_list()
        channels_num = shape[3]
        # scale factor
        gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
        # shift value
        beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2]) # per channel

        update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
        update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))

        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)

        bn_mean = tf.cond(self.is_training, lambda: tf.identity(batch_mean), lambda: tf.identity(moving_mean))
        bn_var = tf.cond(self.is_training, lambda: tf.identity(batch_var), lambda: tf.identity(moving_var))

        with tf.variable_scope(name + "_batchnorm_op"):
            inv = tf.math.rsqrt(bn_var + epsilon)
            inv *= gamma
            output = ((x*inv) - (bn_mean*inv)) + beta

    return output
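On the performance side, the last line of the function makes three passes over the full tensor: multiply by inv, subtract the per-channel bn_mean*inv, then add beta. Since bn_mean*inv and beta are both per-channel vectors, they can be combined first, leaving one multiply and one add over x; tf.nn.batch_normalization uses this same rearrangement. The NumPy sketch below (function names are illustrative) checks that both groupings agree. It removes one elementwise pass but is not a fused kernel, so it may not close the whole gap.

```python
import numpy as np


def batchnorm_reference(x, mean, var, gamma, beta, eps=1e-3):
    """Same grouping as the original: ((x*inv) - (mean*inv)) + beta."""
    inv = gamma / np.sqrt(var + eps)
    return ((x * inv) - (mean * inv)) + beta


def batchnorm_folded(x, mean, var, gamma, beta, eps=1e-3):
    """Fold the per-channel terms first: x*inv + (beta - mean*inv)."""
    inv = gamma / np.sqrt(var + eps)      # per-channel, shape [C]
    shift = beta - mean * inv             # per-channel, shape [C]
    return x * inv + shift                # one multiply and one add over x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 8, 3)).astype(np.float32)      # NHWC input
mean, var = x.mean(axis=(0, 1, 2)), x.var(axis=(0, 1, 2))  # per channel, like tf.nn.moments
gamma = np.array([1.0, 0.5, 2.0], dtype=np.float32)
beta = np.array([0.0, 0.1, -0.2], dtype=np.float32)
y_ref = batchnorm_reference(x, mean, var, gamma, beta)
y_fold = batchnorm_folded(x, mean, var, gamma, beta)
```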

I would appreciate help with any of the following questions:

  • Any ideas on how to improve the performance (reduce runtime) of my solution?
  • Is it possible to add my own operators to the variable pipeline of tf.layers.batch_normalization, before the normalization is applied?
  • Any other solution to the same problem?


  • tf.nn.fused_batch_norm is optimized and did the trick.

    I had to create two subgraphs, one per mode, since fused_batch_norm's interface does not accept a conditional training/test mode (is_training is a Python bool, not a tensor, so the graph it builds is not conditional). I added the condition afterwards (see below). Even with the two subgraphs, the runtime is about the same as with tf.layers.batch_normalization.

    Here's the final solution (I'd still appreciate any comment or advice for improvements):

    def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
        with tf.variable_scope(name):
            shape = x.get_shape().as_list()
            channels_num = shape[3]
            # scale factor
            gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
            # shift value
            beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
            moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
            moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
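            # Training mode: normalize with the batch statistics and return them
            (output_train, batch_mean, batch_var) = tf.nn.fused_batch_norm(x,
                                                                     gamma,
                                                                     beta,  # pylint: disable=invalid-name
                                                                     epsilon=epsilon,
                                                                     is_training=True)
            # Inference mode: normalize with the moving averages
            (output_test, _, _) = tf.nn.fused_batch_norm(x,
                                                         gamma,
                                                         beta,  # pylint: disable=invalid-name
                                                         mean=moving_mean,
                                                         variance=moving_var,
                                                         epsilon=epsilon,
                                                         is_training=False)
            # Both fused ops are created outside tf.cond, so both run on every step;
            # tf.cond only selects which result is returned.
            output = tf.cond(self.is_training, lambda: tf.identity(output_train), lambda: tf.identity(output_test))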
            update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
            update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)
        return output
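The structure of the final solution (a training branch that normalizes with batch statistics and returns them, an inference branch that uses the stored moving averages, and an exponential moving-average update driven by the training branch) can be checked with a small NumPy sketch. The function names are hypothetical, the sketch uses the biased batch variance throughout, and it does not reproduce kernel details such as any correction applied to the variance that fused_batch_norm returns.

```python
import numpy as np


def bn_train(x, gamma, beta, eps=1e-3):
    """Training mode: normalize with batch statistics and return them."""
    mean = x.mean(axis=(0, 1, 2))          # per channel (NHWC)
    var = x.var(axis=(0, 1, 2))            # biased, like tf.nn.moments
    y = (x - mean) / np.sqrt(var + eps) * gamma + beta
    return y, mean, var


def bn_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Inference mode: normalize with the stored moving averages."""
    return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta


def ema_update(moving, batch_stat, decay=0.99):
    """Same rule as the UPDATE_OPS: decay*moving + (1-decay)*batch."""
    return decay * moving + (1.0 - decay) * batch_stat


rng = np.random.default_rng(1)
x = rng.normal(loc=2.0, scale=3.0, size=(8, 4, 4, 2))
gamma, beta = np.array([1.0, 0.5]), np.array([0.0, 1.0])
moving_mean, moving_var = np.zeros(2), np.ones(2)   # same initializers as the layer

# Repeatedly "train" on the same batch: the moving averages approach its statistics.
for _ in range(2000):
    y_train, batch_mean, batch_var = bn_train(x, gamma, beta)
    moving_mean = ema_update(moving_mean, batch_mean)
    moving_var = ema_update(moving_var, batch_var)

y_test = bn_infer(x, gamma, beta, moving_mean, moving_var)
```

After enough updates on a fixed batch, the inference branch reproduces the training branch, which is a quick sanity check that the two subgraphs and the update ops are wired consistently.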