Tags: tensorflow, neural-network, batch-normalization

In TensorFlow 2, does using model.fit automatically set the "training" flag in a BatchNormalization layer?


I believe that when using batch normalization layers in TensorFlow, it is important to set the training flag: True when training and False on validation data. Is this correct? If so, does model.fit infer this flag automatically, or should I set it manually? If I have to set it manually, how?


Solution

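  • In TensorFlow 2, you don't have to set the training flag yourself when you train with model.fit. Keras calls each layer with training=True on the training batches and with training=False on the validation data, as well as in model.evaluate and model.predict. So, just use batch normalization normally, like: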

    tf.keras.layers.BatchNormalization()
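
One way to check this behaviour yourself is to watch the layer's moving mean, which is only updated when the layer runs in training mode. The following is a minimal sketch, not a definitive test: the toy data, the model shape, and the variable names (`bn`, `mean_before`, and so on) are assumptions made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy data with a mean far from 0, so changes to the moving mean are easy to see.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(256, 4)).astype("float32")
y = rng.normal(size=(256, 1)).astype("float32")

# Keep a handle on the BatchNormalization layer so its statistics can be inspected.
bn = tf.keras.layers.BatchNormalization()
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), bn, tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

mean_before = bn.moving_mean.numpy().copy()  # initialised to zeros

# evaluate() and predict() run the layer in inference mode,
# so the moving statistics are left untouched.
model.evaluate(x, y, verbose=0)
model.predict(x, verbose=0)
mean_after_eval = bn.moving_mean.numpy().copy()

# fit() runs the training batches in training mode, which updates the
# moving statistics. The validation pass inside fit() uses inference mode.
model.fit(x, y, epochs=1, batch_size=32, validation_data=(x, y), verbose=0)
mean_after_fit = bn.moving_mean.numpy().copy()

stats_unchanged_by_eval = bool(np.allclose(mean_before, mean_after_eval))
stats_changed_by_fit = not np.allclose(mean_before, mean_after_fit)
print("unchanged by evaluate/predict:", stats_unchanged_by_eval)
print("changed by fit:", stats_changed_by_fit)
```

The flag only needs to be passed by hand when you call the model or layer directly, for example in a custom training loop: model(x, training=True) for the training step and model(x, training=False) for validation.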