
BatchNormalization in Keras


How do I update the moving mean and moving variance in Keras BatchNormalization?

I found this in the TensorFlow documentation, but I don't know where to put train_op or how to use it with Keras models:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

None of the posts I found say what to do with train_op, or whether it can be used with model.compile.


Solution

  • You do not need to manually update the moving mean and variance if you are using the BatchNormalization layer. Keras takes care of updating these parameters during training (with model.fit, model.fit_generator and friends) and keeps them fixed during testing (with model.predict and model.evaluate).

    Keras also keeps track of the learning phase so different codepaths run during training and validation/testing.
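
    To make this concrete, here is a minimal sketch (assuming tf.keras on TensorFlow 2.x and random toy data; the model and data are illustrative, not taken from the original post). The BatchNormalization layer's moving mean and variance change after model.fit alone, with no update_ops or control_dependencies, and model.predict then runs with those statistics frozen:

        # Sketch: BatchNormalization moving statistics are updated by model.fit
        # itself; no manual train_op / update_ops handling is required.
        import numpy as np
        import tensorflow as tf

        # Toy data: 256 samples, 20 features, binary labels (illustrative only).
        x = np.random.randn(256, 20).astype("float32")
        y = np.random.randint(0, 2, size=(256, 1)).astype("float32")

        bn = tf.keras.layers.BatchNormalization()
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation="relu", input_shape=(20,)),
            bn,
            tf.keras.layers.Dense(1, activation="sigmoid"),
        ])
        model.compile(optimizer="adam", loss="binary_crossentropy")

        # Weights of BatchNormalization are [gamma, beta, moving_mean, moving_variance];
        # before training the moving statistics are their initial zeros/ones.
        print(bn.get_weights()[2][:3], bn.get_weights()[3][:3])

        # Training: Keras runs the moving-statistics updates internally.
        model.fit(x, y, epochs=2, batch_size=32, verbose=0)

        # The moving statistics have changed without any manual update ops.
        print(bn.get_weights()[2][:3], bn.get_weights()[3][:3])

        # Inference: predict/evaluate run in inference mode, so the (now fixed)
        # moving statistics are used instead of per-batch statistics.
        model.predict(x[:8])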