Tags: tensorflow, huggingface-transformers

How to freeze a TFBertForSequenceClassification pre-trained model?


If I am using the TensorFlow version of a Hugging Face transformer, how do I freeze the weights of the pre-trained encoder so that only the weights of the head layer are optimized?

For the PyTorch implementation, this is done with:

for param in model.base_model.parameters():
    param.requires_grad = False
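
As a sanity check (a minimal sketch, assuming the standard BertForSequenceClassification model), only the head's parameters should still require gradients after freezing:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Freeze every parameter of the pre-trained encoder
for param in model.base_model.parameters():
    param.requires_grad = False

# Only the classifier head should remain trainable now
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")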

I would like to do the same for the TensorFlow implementation.


Solution

  • Found a way to do it: freeze the base model BEFORE compiling it.

    model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
    # layers[0] is the pre-trained BERT encoder (TFBertMainLayer);
    # marking it non-trainable leaves only the classification head to optimize
    model.layers[0].trainable = False
    model.compile(...)
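
    For context, here is a fuller sketch of the same approach; the optimizer, learning rate, loss, and metrics below are illustrative assumptions, not part of the original answer:

    import tensorflow as tf
    from transformers import TFBertForSequenceClassification

    model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")

    # Freeze the pre-trained encoder before compiling
    model.layers[0].trainable = False

    # Illustrative training setup (hyperparameters are assumptions)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    # summary() should now report the encoder's weights under "Non-trainable params"
    model.summary()

    An equivalent spelling is model.bert.trainable = False, since the encoder is exposed as the model's bert attribute.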