
Why is it necessary to freeze all inner state of a Batch Normalization layer when fine-tuning?


The following comes from the Keras tutorial:

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.
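To make the quoted behavior concrete, here is a minimal NumPy sketch of a batch-norm layer that follows the rule described above: with trainable=False it runs in inference mode, normalizing with its stored moving mean and variance and never updating them, even when called with training=True. The class SimpleBatchNorm is hypothetical and is not the Keras implementation. The values momentum=0.99 and epsilon=1e-3 are assumed to match the Keras defaults, and the random data is purely illustrative.

```python
import numpy as np


class SimpleBatchNorm:
    """Hypothetical, minimal batch-norm layer over axis 0 (the batch axis).

    Mimics the TF2 Keras rule: when `trainable` is False the layer runs in
    inference mode even if it is called with training=True.
    """

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.gamma = np.ones(num_features)         # learned scale
        self.beta = np.zeros(num_features)         # learned shift
        self.moving_mean = np.zeros(num_features)  # non-trainable state
        self.moving_var = np.ones(num_features)    # non-trainable state
        self.momentum = momentum
        self.epsilon = epsilon
        self.trainable = True

    def __call__(self, x, training=False):
        if training and self.trainable:
            # Training mode: normalize with this batch's statistics and
            # move the stored averages toward them.
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode (or frozen layer): use the stored statistics
            # and leave them untouched.
            mean, var = self.moving_mean, self.moving_var
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = SimpleBatchNorm(num_features=3)
original_batch = rng.normal(0.0, 1.0, size=(64, 3))
new_batch = rng.normal(4.0, 3.0, size=(64, 3))  # assumed "fine-tuning" data

train_out = bn(original_batch, training=True)   # moving stats are updated
bn.trainable = False                            # freeze the layer
saved_mean = bn.moving_mean.copy()
frozen_out = bn(new_batch, training=True)       # still runs in inference mode
inference_out = bn(new_batch, training=False)
```

Because the frozen layer ignores the composition of the incoming batch, the layers above it keep receiving inputs normalized with the same statistics they were trained with.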

Why should we freeze the layer when fine-tuning a convolutional neural network? Is it because of some mechanism in TensorFlow Keras, or because of the batch normalization algorithm itself? I ran an experiment myself and found that if trainable is not set to False, the model tends to catastrophically forget what it has learned before and returns a very large loss in the first few epochs. What's the reason for that?


Solution

  • During training, the variation in batch statistics from one mini-batch to the next acts as a regularizer that can improve the network's ability to generalize. This can help reduce overfitting when training for many iterations. Conversely, a very large batch size can harm generalization, because batch statistics vary less and the regularizing effect is weaker.

    When fine-tuning on a new dataset, the batch statistics are likely to be very different if the fine-tuning examples have different characteristics from the examples in the original training dataset. Therefore, if batch normalization is not frozen, its moving mean and variance are updated toward the new data, and the network learns new batch normalization parameters (gamma and beta in the batch normalization paper) that differ from what the other network parameters were optimised for during the original training. Re-learning all the other network parameters to compensate is often undesirable during fine-tuning, either because of the training time required or because the fine-tuning dataset is small. Freezing batch normalization avoids this issue.
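A rough NumPy sketch of the two effects described in this answer, using assumed numbers rather than anything from the original experiment. First, per-batch means of unit-variance data fluctuate with a spread of roughly 1/sqrt(batch_size), so large batches give less noisy statistics and hence less regularization. Second, if the layer is not frozen, its moving mean is pulled from the pretrained value toward the fine-tuning data at a rate set by the momentum (0.99 is assumed here to match the Keras default). The helper batch_mean_spread and all data values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical activations of a single channel: mean 0, std 1.
population = rng.normal(0.0, 1.0, size=100_000)


def batch_mean_spread(batch_size, num_batches=2_000):
    """Standard deviation of per-batch means over random mini-batches."""
    idx = rng.integers(0, population.size, size=(num_batches, batch_size))
    return population[idx].mean(axis=1).std()


small_spread = batch_mean_spread(8)    # noisy statistics, roughly 1/sqrt(8)
large_spread = batch_mean_spread(512)  # stable statistics, roughly 1/sqrt(512)

# Drift of the moving mean when batch norm is NOT frozen during fine-tuning.
momentum = 0.99                 # assumed Keras default
old_mean, new_mean = 0.0, 4.0   # assumed pretrained vs. fine-tuning statistics
num_steps = 300
moving_mean = old_mean
for _ in range(num_steps):
    # Same exponential moving average update a training-mode BN layer applies.
    moving_mean = momentum * moving_mean + (1 - momentum) * new_mean
```

After a few hundred steps the stored mean mostly reflects the new data (here it is about 95% of the way from 0 to 4), so the layers above it receive inputs normalized differently from what they were trained on.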