Tags: tensorflow, bert-language-model, tensorflow-hub

Freezing of BERT layers while using tfhub module


In this link, the author says:

import tensorflow_hub as hub
module = hub.Module(<<Module URL as string>>, trainable=True)

If the user wishes to fine-tune/modify the weights of the model, this parameter has to be set to True. So my question is: if I set this to False, does that mean I am freezing all the layers of BERT, which is my intention? I want to know whether my approach is right.


Solution

  • I have a multi-part answer for you.

    How to freeze a Module

    It all comes down to how your optimizer is set up. The usual approach for TF1 is to initialize it with all Variables found in the TRAINABLE_VARIABLES collection. The doc for hub.Module says about trainable: "If False, no variables are added to TRAINABLE_VARIABLES collection, ...". So, yes, setting trainable=False (explicitly or by default) freezes the module in the standard usage of TF1.
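    As a rough sketch of that setup (assuming a TF1.x environment; the module URL and the tiny classifier head below are only illustrative, substitute your own):

    import tensorflow as tf  # TF1.x
    import tensorflow_hub as hub

    # Example module URL; substitute the BERT module you actually use.
    BERT_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

    input_ids = tf.placeholder(tf.int32, [None, 128])
    input_mask = tf.placeholder(tf.int32, [None, 128])
    segment_ids = tf.placeholder(tf.int32, [None, 128])
    labels = tf.placeholder(tf.int32, [None])

    # trainable=False (also the default): the module's weights are NOT added to
    # the TRAINABLE_VARIABLES collection, so a standard TF1 optimizer never sees them.
    bert = hub.Module(BERT_URL, trainable=False)
    outputs = bert(
        dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
        signature="tokens", as_dict=True)

    # Only this classifier head ends up in TRAINABLE_VARIABLES.
    logits = tf.layers.dense(outputs["pooled_output"], 2)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # minimize() defaults to var_list=tf.trainable_variables(), i.e. the head only,
    # so BERT stays frozen.
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)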

    Why not to freeze BERT

    That said, BERT is meant to be fine-tuned. The paper talks about the feature-based (i.e., frozen) vs fine-tuning approaches in more general terms, but the module doc spells it out clearly: "fine-tuning all parameters is the recommended practice." This gives the final parts of computing the pooled output a better shot at adapting to the features that matter most for the task at hand.

    If you intend to follow this advice, please also mind tensorflow.org/hub/tf1_hub_module#fine-tuning and pick the correct graph version: BERT uses dropout regularization during training, and you need to set hub.Module(..., tags={"train"}) to get that. But for inference (in evaluation and prediction), where dropout does nothing, you omit the tags= argument (or set it to the empty set() or to None).
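    A minimal sketch of that distinction (again assuming a TF1.x BERT module; the URL is just an example):

    import tensorflow_hub as hub

    BERT_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"  # example URL
    is_training = True  # flip to False when building the eval/predict graph

    bert = hub.Module(
        BERT_URL,
        trainable=True,                           # fine-tune all parameters
        tags={"train"} if is_training else None)  # the "train" graph enables dropout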

    Outlook: TF2

    You asked about hub.Module(), which is an API for TF1, so I answered in that context. The same considerations apply for BERT in TF2 SavedModel format. There, it's all about setting hub.KerasLayer(..., trainable=True) or not, but the need to select a graph version has gone away (the layer picks up Keras' training state and applies it under the hood).
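    For illustration, a sketch of the TF2 equivalent (assuming one of the TF2 BERT SavedModels on TF Hub that take a dict of input_word_ids/input_mask/input_type_ids and return a dict with pooled_output; adjust to the model you pick):

    import tensorflow as tf  # TF2.x
    import tensorflow_hub as hub

    # Example TF2 SavedModel URL; substitute the BERT encoder you actually use.
    BERT_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"

    # trainable=True fine-tunes BERT; trainable=False freezes it.
    encoder = hub.KerasLayer(BERT_URL, trainable=True)

    inputs = dict(
        input_word_ids=tf.keras.Input(shape=(128,), dtype=tf.int32),
        input_mask=tf.keras.Input(shape=(128,), dtype=tf.int32),
        input_type_ids=tf.keras.Input(shape=(128,), dtype=tf.int32))

    # No tags= needed: the layer picks up Keras' training flag, so dropout is
    # active in model.fit() and disabled in evaluate()/predict().
    logits = tf.keras.layers.Dense(2)(encoder(inputs)["pooled_output"])
    model = tf.keras.Model(inputs, logits)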

    Happy training!