Tags: machine-learning, deep-learning, nlp, pytorch, bert-language-model

BertModel and BertForMaskedLM weights count


I want to understand the BertForMaskedLM model. In the Hugging Face GitHub code, BertForMaskedLM is the BERT model with two additional linear layers on top, of shape (input 768, output 768) and (input 768, output 30522). So the count of all weights should be the weights of BertModel + 768 * 768 + 768 * 30522, but when I check, the numbers don't match.

from transformers import BertModel, BertForMaskedLM
import torch

bertmodel = BertModel.from_pretrained('bert-base-uncased')
bertForMaskedLM = BertForMaskedLM.from_pretrained('bert-base-uncased')

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(bertmodel)
# output: 109482240
count_parameters(bertForMaskedLM)
# output: 109514298

109482240 + 768 * 768 + 768 * 30522 != 109514298  # True, the counts don't match

What am I doing wrong?


Solution

  • Summing numel() over model.parameters() counts each unique parameter tensor exactly once, and BertForMaskedLM ties the decoder weight of its MLM head to the input word-embedding matrix. That shared 768 x 30522 matrix therefore appears only once in your sum, which is why the naive arithmetic doesn't add up. (Your expected count also ignores the bias terms of the two extra layers.)
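
    You can check the tie directly. A minimal sketch (the attribute paths and the remove_duplicate keyword assume reasonably recent transformers and PyTorch versions):

    # The MLM decoder weight and the input word embeddings are the same tensor:
    print(bertForMaskedLM.cls.predictions.decoder.weight
          is bertForMaskedLM.bert.embeddings.word_embeddings.weight)
    # True

    # Counting per parameter name instead of per unique tensor restores the
    # tied copy:
    print(sum(p.numel() for _, p in
              bertForMaskedLM.named_parameters(remove_duplicate=False)))
    # 132955194

    For a layer-by-layer breakdown of both models, try the third-party torchinfo package: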

    from torchinfo import summary
    
    print(summary(bertmodel))
    

    Output: [torchinfo layer-by-layer summary of bertmodel; total trainable params: 109,482,240]


    print(summary(bertForMaskedLM))
    

    Output: [torchinfo layer-by-layer summary of bertForMaskedLM; total trainable params: 132,955,194]

    From the above outputs we can see that the total numbers of trainable params reported for the two models are:
    bertmodel: 109,482,240
    bertForMaskedLM: 132,955,194

    In order to understand the difference, let's have a look at the last module of each model. The embeddings and encoder are identical in both; note that bertForMaskedLM drops the pooler, but its head's transform dense layer has exactly the same shape (768 x 768 plus bias), so those parameters cancel out.
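
    One way to print just these submodules (attribute names as in the module trees below; a quick sketch):

    print(bertmodel.pooler)       # last module of BertModel
    print(bertForMaskedLM.cls)    # MLM head of BertForMaskedLM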

    bertmodel:

    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
    

    bertForMaskedLM:

    (cls): BertOnlyMLMHead(
      (predictions): BertLMPredictionHead(
        (transform): BertPredictionHeadTransform(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
        (decoder): Linear(in_features=768, out_features=30522, bias=True)
      )
    )
    

    The only additions are the LayerNorm layer (2 * 768 params for the per-feature gammas and betas) and the decoder layer (769 * 30522 params: for y = A*x + b with A of size (n x m) and b of size (n x 1), the total is n * (m + 1)).

    Params for bertForMaskedLM (as counted by torchinfo) = 109482240 + 2 * 768 + 769 * 30522 = 132955194
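
    To reconcile this with the 109514298 returned by your count_parameters(): because the decoder weight is tied to the word embeddings, model.parameters() counts that 768 * 30522 matrix only once, so only the LayerNorm parameters and the decoder bias are genuinely new tensors. A quick arithmetic check (a sketch; the constants are the hidden size and vocab size of bert-base-uncased):

    H, V = 768, 30522  # hidden size and vocab size of bert-base-uncased

    base = 109482240   # BertModel: embeddings + encoder + pooler

    # torchinfo view: the pooler (H*H + H params) is swapped for the head's
    # transform dense of the same shape, then LayerNorm and the full decoder
    # (weights + bias) are added.
    assert base + 2 * H + (H + 1) * V == 132955194

    # parameters() view: the decoder weight (H * V) is tied to the input
    # embeddings and counted only once, so only its bias is new.
    assert base + 2 * H + V == 109514298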