In the Hugging Face transformers code, many of the fine-tuning models have an init_weights function.
For example (here), there is an init_weights call at the end of __init__:
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
As far as I know, it will call the following code:
def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
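If I read the base class correctly, self.init_weights() essentially walks the whole module tree and hands every submodule to _init_weights. A simplified sketch of that dispatch (not the exact library code, which also handles head pruning and weight tying) would be:

def init_weights(self):
    # recursively apply _init_weights to every submodule of the model
    self.apply(self._init_weights)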
My question is: if we are loading a pre-trained model, why do we need to initialize the weights of every module?
I guess I must be misunderstanding something here.
Have a look at the code for .from_pretrained(). What actually happens is something like this:

- find the correct base model class to initialise
- initialise that class with pseudo-random initialisation (by using the _init_weights function that you mention)
- find the file with the pretrained weights
- overwrite the weights of the model that we just created with the pretrained weights where applicable

This ensures that layers that were not pretrained (e.g. in some cases the final classification layer) do get initialised in _init_weights but don't get overridden.
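To make the "initialise first, then overwrite" order concrete, here is a hypothetical, heavily simplified sketch of that flow. The function name from_pretrained_sketch and the checkpoint_path argument are made up for illustration; the real logic lives in PreTrainedModel.from_pretrained:

import torch

def from_pretrained_sketch(model_cls, config, checkpoint_path):
    # 1) building the model calls __init__, which calls self.init_weights(),
    #    so every module (including a fresh classification head) starts out
    #    pseudo-randomly initialised
    model = model_cls(config)

    # 2) read the pretrained weights from disk
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # 3) copy the pretrained tensors over the random ones wherever the
    #    parameter names match; with strict=False, keys that are absent from
    #    the checkpoint (e.g. classifier.weight) simply keep the random
    #    values they received in step 1
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("still randomly initialised:", missing_keys)
    return model

So the random initialisation is not wasted work: it is exactly what the layers that have no pretrained counterpart end up using.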