What's the difference between the attributes 'trainable' and 'training' in the BatchNormalization layer in Keras/TensorFlow?


According to the official TensorFlow documentation:

About setting layer.trainable = False on a BatchNormalization layer:
The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run.
Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.
However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

I don't quite understand the terms 'frozen state' and 'inference mode' in this context. I tried fine-tuning by setting trainable to False, and I found that the moving mean and moving variance are not being updated.
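For reference, this observation can be reproduced with a minimal check along the following lines (a sketch assuming TensorFlow 2.x in eager mode; the layer and data are made up for illustration):

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.randn(32, 4).astype("float32")

bn(x, training=True)                         # build the layer and run one training-mode step
mean_before = bn.moving_mean.numpy().copy()  # snapshot the moving mean

bn.trainable = False
bn(x, training=True)                         # with trainable=False, BN runs in inference mode anyway
print(np.allclose(mean_before, bn.moving_mean.numpy()))  # True: moving statistics are frozen
```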

So I have the following questions:

  1. What's the difference between the two attributes training and trainable?
  2. Are gamma and beta updated during training if trainable is set to False?
  3. Why is it necessary to set trainable to False when fine-tuning?

Solution

  • What's the difference between the two attributes training and trainable?

    trainable: (if True) the trainable weights of the layer will be updated during backpropagation.

    training: some layers behave differently during the training and inference (testing) phases; Dropout and BatchNormalization are two examples. This argument tells the layer which of the two modes it should run in.
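    A short demonstration of the distinction (a sketch assuming TF 2.x; the toy input is only for illustration):

    ```python
    import tensorflow as tf

    x = tf.ones((1, 4))

    # 'training' is a call-time argument: it switches the layer's behaviour per call.
    drop = tf.keras.layers.Dropout(0.5)
    print(drop(x, training=True))   # roughly half the units zeroed, the rest scaled up
    print(drop(x, training=False))  # identity: inference behaviour

    # 'trainable' is a layer attribute: it controls whether weights receive gradient updates.
    dense = tf.keras.layers.Dense(2)
    dense(x)                        # build the layer
    dense.trainable = False
    print(dense.trainable_weights)  # [] -- kernel and bias are now frozen
    ```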

  • Are gamma and beta updated during training if trainable is set to False?

    Since gamma and beta are trainable parameters of the BN layer, they will NOT be updated during training if trainable is set to False.
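    You can verify this by inspecting the layer's weight collections (again a TF 2.x sketch):

    ```python
    import tensorflow as tf

    bn = tf.keras.layers.BatchNormalization()
    bn.build((None, 4))

    print(bn.trainable_weights)      # gamma and beta
    print(bn.non_trainable_weights)  # moving_mean and moving_variance

    bn.trainable = False
    print(bn.trainable_weights)      # [] -- gamma and beta no longer receive gradient updates
    ```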

  • Why is it necessary to set trainable to False when fine-tuning?

    When fine-tuning, we first add our own randomly initialized classification FC layer on top, while the pre-trained model is already calibrated (to some extent) for the task.

    As an analogy, think of it like this.

    You have a number line from 0 to 10. On this line, '0' represents a completely random model and '10' a perfect one. Our pre-trained model sits somewhere around 5, 6, or 7, i.e. most probably better than a random model. The FC layer we added on top starts at '0', since it is randomly initialized.

    We set trainable = False for the pre-trained model so that the FC layer can quickly catch up to the level of the pre-trained model, i.e. so we can use a higher learning rate. If we keep the pre-trained model trainable and use a higher learning rate, the large gradients flowing back from the random FC layer will wreck the pre-trained weights.

    So initially, we set a higher learning rate with trainable = False on the pre-trained model and train only the FC layer. After that, we unfreeze the pre-trained model and continue with a very low learning rate, as sketched below.
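    A minimal sketch of this two-phase recipe (assumptions: MobileNetV2 as the pre-trained base, a 10-class problem, and a hypothetical dataset train_ds; the learning rates are only illustrative):

    ```python
    import tensorflow as tf

    # Phase 1: freeze the pre-trained base and train only the new FC head.
    base = tf.keras.applications.MobileNetV2(
        include_top=False, pooling="avg", input_shape=(224, 224, 3))
    base.trainable = False  # also puts the base's BN layers into inference mode

    model = tf.keras.Sequential([
        base,
        tf.keras.layers.Dense(10, activation="softmax"),  # randomly initialized head
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),  # relatively high LR for the head
                  loss="sparse_categorical_crossentropy")
    # model.fit(train_ds, epochs=5)   # train_ds is a placeholder dataset

    # Phase 2: unfreeze the base and fine-tune everything with a very low LR.
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # very low LR protects the base weights
                  loss="sparse_categorical_crossentropy")
    # model.fit(train_ds, epochs=5)
    ```

    Note that in Keras, changes to trainable only take effect after recompiling the model, which is why compile() is called again in phase 2.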

    Feel free to ask for more clarification if required, and upvote if you find this helpful.