Tags: python, huggingface-transformers, bert-language-model

Why do we need the init_weights function in a BERT pretrained model in Hugging Face Transformers?


In the Hugging Face Transformers code, many fine-tuning models have an init_weights function. For example (here), there is an init_weights call at the end of __init__.

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

As far as I know, it will call the following code:

def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
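
If I read the source correctly, self.init_weights() ends up applying _init_weights to every submodule via nn.Module.apply. Here is a simplified toy version of what I think happens (my own sketch, not the library's actual code):

import torch.nn as nn

class TinyBertLikeModel(nn.Module):
    """Toy stand-in for a BertPreTrainedModel subclass (illustration only)."""

    def __init__(self, hidden_size=8, num_labels=2, initializer_range=0.02):
        super().__init__()
        self.initializer_range = initializer_range
        self.encoder = nn.Linear(hidden_size, hidden_size)    # pretend "pretrained" part
        self.classifier = nn.Linear(hidden_size, num_labels)  # task-specific head
        self.init_weights()

    def _init_weights(self, module):
        # Same idea as the snippet above, minus the BertLayerNorm branch.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def init_weights(self):
        # nn.Module.apply recursively visits every submodule (and self),
        # so every Linear/Embedding in the tree is initialised once.
        self.apply(self._init_weights)

model = TinyBertLikeModel()
print(model.classifier.weight.std())  # roughly 0.02 (small sample, so noisy)
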

My question is: if we are loading a pre-trained model, why do we need to initialize the weights for every module?

I guess I must be misunderstanding something here.


Solution

  • Have a look at the code for .from_pretrained(). What actually happens is something like this:

    • find the correct base model class to initialise
    • initialise that class with pseudo-random initialisation (by using the _init_weights function that you mention)
    • find the file with the pretrained weights
    • overwrite the weights of the model that we just created with the pretrained weights where applicable

    This ensures that the layers that were not pretrained (e.g. in some cases the final classification layer) do get initialised in _init_weights but are not overridden by the pretrained weights.
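
    To make that load order concrete, here is a rough sketch of steps 1-4 in plain PyTorch, continuing with the TinyBertLikeModel toy model from the question above (the file name is hypothetical; this is not the actual from_pretrained implementation):

    import torch

    # 1.-2. Build the model; __init__ calls init_weights(), so every layer
    #       starts out pseudo-randomly initialised.
    model = TinyBertLikeModel()

    # 3. Find/load the pretrained weights from disk (hypothetical file name).
    #    Suppose this state dict only contains "encoder.weight" and "encoder.bias".
    pretrained_state = torch.load("pytorch_model.bin", map_location="cpu")

    # 4. Overwrite matching parameters; strict=False means keys missing from
    #    the checkpoint (here: the classifier) simply keep the random
    #    initialisation from step 2.
    missing, unexpected = model.load_state_dict(pretrained_state, strict=False)
    print("kept randomly initialised:", missing)   # ['classifier.weight', 'classifier.bias']
    print("ignored from checkpoint:", unexpected)  # []

    This is also why, when you load e.g. BertForSequenceClassification from a checkpoint that was only pretrained as a language model, Transformers warns that some weights were newly initialized and that you should fine-tune the model on a downstream task before using it.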