I am trying to fine-tune 'bert-base-uncased' on a dataset for a text classification task. Here is how I am loading the model:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

num_labels = 13  # number of target classes; 13 matches the 9,997 classifier params in the summary below
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
Since bert-base has 12 encoder layers, I want to fine-tune only the last 2 of them to prevent overfitting. Setting model.layers[i].trainable = False will not help, because model.layers[0] is the whole BERT main layer, and setting its trainable attribute to False freezes every layer of BERT. Here is the architecture of model:
Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bert (TFBertMainLayer) multiple 109482240
dropout_37 (Dropout) multiple 0
classifier (Dense) multiple 9997
=================================================================
Total params: 109,492,237
Trainable params: 109,492,237
Non-trainable params: 0
_________________________________________________________________
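For reference, iterating over model.layers confirms that the top level only exposes these three blocks, so toggling trainable there is all-or-nothing for BERT (a minimal check, assuming the model object loaded above):

for layer in model.layers:
    print(layer.name, type(layer).__name__, layer.trainable)
# bert        TFBertMainLayer  True
# dropout_37  Dropout          True
# classifier  Dense            True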
I also tried model.layers[0].weights[j]._trainable = False, but the weights list has 199 elements, with shapes like TensorShape([30522, 768]), so I could not figure out which weights belong to the last 2 layers.
Can anyone help me fix this?
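Printing the entries at least shows what is in there (just a diagnostic sketch, assuming the model above; the exact weight names depend on the transformers version):

for j, w in enumerate(model.layers[0].weights):
    print(j, w.name, w.shape)
# the names embed the encoder layer index (e.g. '...encoder/layer_._11/...'),
# but picking the right indices out of 199 entries by hand felt too error-prone.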
I found the answer and am sharing it here; I hope it helps others. With the help of this article, which is about fine-tuning BERT using PyTorch, the equivalent in tensorflow2.keras is:
model.bert.encoder.layer[i].trainable = False
where i is the index of the layer to freeze.
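For example, to freeze everything except the last 2 of the 12 encoder layers, one could do something like the sketch below (based on the line above, assuming the model from the question; freezing the embeddings as well is my own assumption, and the optimizer/loss are just placeholders):

# Freeze the embeddings and the first 10 encoder layers;
# leave encoder layers 10 and 11 (the last two) trainable.
model.bert.embeddings.trainable = False
for i in range(10):
    model.bert.encoder.layer[i].trainable = False

# Re-compile after changing trainable flags so they take effect.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()  # "Non-trainable params" should no longer be 0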