Tags: python, tensorflow, keras

How to use tf.keras.layers.BatchNormalization() in a custom training loop?


I came back to TensorFlow after quite a while, and it seems the landscape has completely changed.

Previously, I used tf.contrib....batch_normalization with the following in my training loop:

# TF 1.x pattern: batch-norm moving-average updates are collected in
# UPDATE_OPS and must run as explicit control dependencies of the
# training op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(cnn.loss, global_step=global_step)

But now contrib is nowhere to be found, and tf.keras.layers.BatchNormalization does not work the same way. Also, I couldn't find any instructions on training it in the documentation.
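
For reference, a minimal sketch of how a custom training loop with BatchNormalization typically looks in TF 2.x; the model, optimizer, and loss function below are placeholders, not code from the question:

import tensorflow as tf

# Placeholder model containing a BatchNormalization layer.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10),
])

optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # training=True puts BatchNormalization in training mode: it
        # normalizes with batch statistics and updates its moving
        # mean/variance as a side effect of the forward pass, so the
        # old UPDATE_OPS / control_dependencies pattern is not needed.
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# At inference time, call the model with training=False so the layer
# uses its accumulated moving statistics instead of batch statistics.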

So, any information or help is appreciated.


Solution

  • I started using PyTorch. It solved the problem.
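
For comparison, a minimal sketch of the equivalent PyTorch loop, where BatchNorm's running statistics are updated automatically during the forward pass and model.train() / model.eval() switch between batch and running statistics. The model and its sizes below are placeholders:

import torch
import torch.nn as nn

# Placeholder model with a batch-norm layer.
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

def train_step(x, y):
    # model.train() puts BatchNorm1d in training mode: it normalizes
    # with batch statistics and updates its running mean/variance in
    # the forward pass, with no extra update ops to wire up.
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

# For evaluation, switch to the stored running statistics:
# model.eval()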