I am using a TensorFlow neural net to figure out how batch normalization works so I can replicate it in my own library. I've run into this strange issue:
When you initialize a neural net layer, all biases (or in the case of batchnorm, betas) are set to 0, so the layer should just multiply the input values by the weights, and that's about it. Now, from what I understand about batchnorm, during training it calculates the means and the variances of the layer inputs over the minibatch it is being fed, and then normalizes the input: output = (input - mean) / sqrt(variance + eps).
So, if all the input values of your minibatch are the same, then during training batchnorm will subtract the mean (equal to each value) from the input value, so the net should output 0, regardless of input, right?
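That reasoning can be sanity-checked with a small NumPy sketch (my own illustration of the formula above, not TensorFlow's code) of training-mode batch norm applied to a minibatch of identical samples:

```python
import numpy as np

def batchnorm_train(x, gamma=1.0, beta=0.0, eps=1e-5):
    # Training mode: normalize with the statistics of THIS minibatch.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# A minibatch where every sample is identical.
batch = np.full((4, 3), 7.0)
out = batchnorm_train(batch)
print(out)  # all zeros: (7 - 7) / sqrt(0 + eps) = 0
```

With beta = 0 and identical inputs, every output is exactly 0, which is what I expected to see.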
And it doesn't. In fact, it looks like all the means during calculation are 0 and the variances are 1, as if it is using the running averages of those values. So either I don't understand how batchnorm works, or the code is simply using batchnorm incorrectly. Here is how it is initialized in the code I'm using:
layer = tflearn.fully_connected(layer, 10, weights_init=w_init)
layer = tflearn.layers.normalization.batch_normalization(layer)
layer = tflearn.activations.leaky_relu(layer)
The other option is that it is being used incorrectly during training, but I would like to eliminate the other possible explanations first.
The TensorFlow batch norm implementation has some update ops that are not included in the training op's dependencies by default. You have to add the dependencies explicitly. Quoting the docs:
[W]hen training, the `moving_mean` and `moving_variance` need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency to the `train_op`. Also, be sure to add any `batch_normalization` ops before getting the `update_ops` collection. Otherwise, `update_ops` will be empty, and training/inference will not work properly. For example:
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
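This also explains the exact symptom you saw (mean 0, variance 1). TensorFlow initializes `moving_mean` to zeros and `moving_variance` to ones, and if the update ops never run, those initial values are what inference uses. Here is a NumPy sketch of that mechanism (an illustration under the default initializations and default momentum, not TensorFlow's actual internals):

```python
import numpy as np

# TensorFlow initializes moving_mean to 0 and moving_variance to 1.
moving_mean = np.zeros(3)
moving_var = np.ones(3)
momentum = 0.99  # the default for tf.layers.batch_normalization

def run_update_ops(batch):
    # Roughly what the ops in the UPDATE_OPS collection do each training step:
    # an exponential moving average of the minibatch statistics.
    global moving_mean, moving_var
    moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean(axis=0)
    moving_var = momentum * moving_var + (1 - momentum) * batch.var(axis=0)

batch = np.full((4, 3), 7.0)

# If the update ops are never wired into the train op, the moving stats
# keep their initial values, so normalization uses mean=0, var=1 and the
# input passes through almost unchanged:
inference_out = (batch - moving_mean) / np.sqrt(moving_var + 1e-5)
print(inference_out)  # ~7.0 everywhere

run_update_ops(batch)  # with the dependency in place, the stats do move
print(moving_mean)     # drifts from 0 toward the batch mean of 7
```

So the net was not outputting 0 because it was normalizing with the never-updated moving statistics rather than the minibatch statistics.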