Tags: python, bert-language-model, huggingface-transformers

Copy one layer's weights from one Huggingface BERT model to another


I have a pre-trained model which I load like so:

from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 2, # The number of output labels--2 for binary classification.
                    # You can increase this for multi-class tasks.   
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

I want to create a new model with the same architecture, and random initial weights, except for the embedding layer:

==== Embedding Layer ====

bert.embeddings.word_embeddings.weight                  (30522, 768)
bert.embeddings.position_embeddings.weight                (512, 768)
bert.embeddings.token_type_embeddings.weight                (2, 768)
bert.embeddings.LayerNorm.weight                              (768,)
bert.embeddings.LayerNorm.bias                                (768,)
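
For reference, the listing above can be reproduced by iterating over the model's named parameters; a minimal sketch (the "bert.embeddings" prefix filter is just an assumption based on the parameter names shown):

# Print the embedding-layer parameters and their shapes.
for name, param in model.named_parameters():
    if name.startswith("bert.embeddings"):
        print(name, tuple(param.shape))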

It seems I can create a new model with the same architecture like this, but then all of its weights are random:

configuration   = model.config
untrained_model = BertForSequenceClassification(configuration)
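
As a quick check that the new model really does start from random weights, you can compare one tensor from each model (a small sketch using torch.equal):

import torch

old_emb = dict(model.named_parameters())['bert.embeddings.word_embeddings.weight']
new_emb = dict(untrained_model.named_parameters())['bert.embeddings.word_embeddings.weight']
print(torch.equal(old_emb, new_emb))  # False: the new embeddings are randomly initialized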

So how do I copy the original model's embedding layer weights over to the new untrained_model?


Solution

  • Weights and biases are just tensors, and you can simply copy them with copy_:

    from transformers import BertForSequenceClassification, BertConfig

    # Pre-trained model whose embedding weights we want to reuse.
    jetfire = BertForSequenceClassification.from_pretrained('bert-base-cased')
    config = BertConfig.from_pretrained('bert-base-cased')

    # Fresh model with the same architecture but randomly initialized weights.
    optimus = BertForSequenceClassification(config)

    # Names of the embedding-layer parameters to transfer.
    parts = [
        'bert.embeddings.word_embeddings.weight',
        'bert.embeddings.position_embeddings.weight',
        'bert.embeddings.token_type_embeddings.weight',
        'bert.embeddings.LayerNorm.weight',
        'bert.embeddings.LayerNorm.bias',
    ]

    def joltElectrify(jetfire, optimus, parts):
        target = dict(optimus.named_parameters())
        source = dict(jetfire.named_parameters())

        # Copy each listed parameter tensor from the source into the target in place.
        for part in parts:
            target[part].data.copy_(source[part].data)

    joltElectrify(jetfire, optimus, parts)
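
    To verify the transfer, you can compare the copied tensors and one untouched layer; a quick sanity check, assuming the code above has run (the encoder parameter name below is just an example):

    import torch

    # Every listed embedding parameter should now be identical in both models.
    for part in parts:
        assert torch.equal(dict(jetfire.named_parameters())[part],
                           dict(optimus.named_parameters())[part])

    # An untouched parameter (here the first encoder layer's query weights) should
    # still differ, since optimus keeps its random initialization elsewhere.
    name = 'bert.encoder.layer.0.attention.self.query.weight'
    print(torch.equal(dict(jetfire.named_parameters())[name],
                      dict(optimus.named_parameters())[name]))  # False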