Tags: python-3.x, keras, tensorflow2.0, huggingface-transformers, bert-language-model

How to freeze some layers of BERT when fine-tuning with tf2.keras


I am trying to fine-tune 'bert-base-uncased' on a dataset for a text classification task. Here is how I download the model:

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

num_labels = 13  # number of target classes; 13 matches the 9,997 classifier params in the summary below
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

As bert-base has 12 transformer layers, I wanted to fine-tune only the last 2 to prevent overfitting. Setting model.layers[i].trainable = False will not help, because model.layers[0] returns the whole BERT base model, and setting its trainable attribute to False would freeze every layer of BERT. Here is the architecture of the model:

Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 bert (TFBertMainLayer)      multiple                  109482240 
                                                                 
 dropout_37 (Dropout)        multiple                  0         
                                                                 
 classifier (Dense)          multiple                  9997      
                                                                 
=================================================================
Total params: 109,492,237
Trainable params: 109,492,237
Non-trainable params: 0
_________________________________________________________________

I also tried model.layers[0].weights[j]._trainable = False, but the weights list has 199 elements with shapes such as TensorShape([30522, 768]), so I could not figure out which weights belong to the last 2 layers. Can anyone help me fix this?
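
For reference, here is how I listed the weights to try to identify them (in the TF implementation the names appear to embed an encoder block index such as layer_._11, but toggling _trainable per weight seemed error-prone):

# List every weight of the BERT main layer with its index, name, and shape.
# Names look like ".../encoder/layer_._11/attention/...", so the block
# index is visible in each name.
for j, w in enumerate(model.layers[0].weights):
    print(j, w.name, w.shape)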


Solution

  • I found the answer and am sharing it here; I hope it helps others. With the help of this article, which is about fine-tuning BERT in PyTorch, the equivalent in tensorflow2.keras is as follows:

    model.bert.encoder.layer[i].trainable = False
    

    where i is the index of the encoder layer to freeze.
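
    Putting it together, here is a minimal sketch that freezes everything except the last 2 encoder layers. Note that freezing the embeddings as well is my addition, not part of the original answer, and num_labels=13 is taken from the question:

    import tensorflow as tf
    from transformers import TFAutoModelForSequenceClassification

    num_labels = 13  # same value as in the question
    model = TFAutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=num_labels
    )

    # Optionally freeze the embedding layer too (an extra, assumed step,
    # not part of the original answer).
    model.bert.embeddings.trainable = False

    # bert-base has 12 encoder layers; freeze all but the last 2.
    for layer in model.bert.encoder.layer[:-2]:
        layer.trainable = False

    # Compile AFTER changing the trainable flags so Keras picks them up.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    Calling model.summary() afterwards should show most of the ~109M parameters as non-trainable, which confirms the freeze took effect.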