How to replace PyTorch model layer's tensor with another layer of same shape in Huggingface model?

Given a Huggingface model, e.g.

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)

I can access a layer's tensor like this:

# Shape [1024, 1024]
model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"]

tensor([[ 0.0167, -0.0422, -0.0425,  ...,  0.0302, -0.0341,  0.0251],
        [ 0.0323,  0.0347, -0.0041,  ..., -0.0722,  0.0031, -0.0351],
        [ 0.0387, -0.0293, -0.0694,  ...,  0.0492,  0.0201, -0.0727],
        [ 0.0035,  0.0081, -0.0337,  ...,  0.0460,  0.0268,  0.0747],
        [ 0.0513,  0.0131,  0.0735,  ..., -0.0127,  0.0144, -0.0400],
        [ 0.0385,  0.0013, -0.0272,  ...,  0.0148,  0.0399,  0.0339]])

Now suppose I have another tensor of the same shape that is defined elsewhere. For illustration I create a random tensor here, but it can be any pre-defined tensor.

import torch
replacement_layer = torch.rand([1024, 1024])

Note: I'm not trying to replace a layer with a random tensor but replace it with a pre-defined one.

When I try to replace the layer's tensor through state_dict(), it doesn't seem to work:

import torch
from transformers import AutoModelForSequenceClassification

# The model with a layer that we want to replace.
model = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)

# A replacement layer.
replacement_layer = torch.rand([1024, 1024])

# Replacing the layer in the statedict.
model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"] = replacement_layer

# Check that the layer is replaced. No, it is not =(
assert torch.equal(
    model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"],
    replacement_layer,
)


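  • state_dict() does not hand you the model's live parameter table. Each call builds a new dictionary, so assigning to one of its keys only rebinds that key in a temporary dictionary and leaves the model itself untouched.

    You can access a model's layers directly with dot notation. Note that a numeric component of a state_dict key, such as the 0 in layer.0, is an index into a list of modules rather than an attribute name, so it is written layer[0]. You also need to wrap your tensor in a torch.nn.Parameter, because assigning a plain tensor to a parameter attribute raises a TypeError.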

    So this should work:

    model.bert.encoder.layer[0].attention.self.query.weight = torch.nn.Parameter(replacement_layer)

    or in full:

    # Note I used the base model for testing
    import torch
    from transformers import AutoModelForSequenceClassification
    # The model with a layer that we want to replace.
    model: torch.nn.Module = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    # A replacement layer.
    replacement_layer = torch.rand([768, 768])
    model.bert.encoder.layer[0].attention.self.query.weight = torch.nn.Parameter(replacement_layer)
    # Check that the layer is replaced
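    assert torch.equal(model.bert.encoder.layer[0].attention.self.query.weight, replacement_layer)
    assert torch.equal(model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"], replacement_layer)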
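
As an alternative to assigning a new torch.nn.Parameter, the existing parameter can be overwritten in place, or a partial state dict can be loaded. The sketch below is only an illustration under stated assumptions: Tiny and Block are hypothetical stand-in modules (so nothing has to be downloaded), and the key "layer.0.query.weight" merely mimics the shape of the BERT key above. The same calls should apply to the Huggingface model, but that is not tested here.

```python
import torch
from torch import nn


# Hypothetical stand-in modules; the nesting mimics keys like "layer.0.query.weight".
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.query = nn.Linear(4, 4)


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.ModuleList([Block(), Block()])


model = Tiny()
key = "layer.0.query.weight"
replacement = torch.rand(4, 4)

# 1) Assigning into the dict returned by state_dict() only changes that
#    temporary dict; the next call builds a fresh one from the model.
model.state_dict()[key] = replacement
unchanged = not torch.equal(model.state_dict()[key], replacement)

# 2) Look the parameter up by its dotted name and copy the values in place.
#    no_grad() keeps autograd from recording the copy.
with torch.no_grad():
    model.get_parameter(key).copy_(replacement)
copied = torch.equal(model.state_dict()[key], replacement)

# 3) Load a state dict that contains only the key to replace.
#    strict=False tolerates all the keys that are left out.
replacement2 = torch.rand(4, 4)
result = model.load_state_dict({key: replacement2}, strict=False)
loaded = torch.equal(model.state_dict()[key], replacement2)
```

Both variants keep the original Parameter object, so an optimizer that was already built from model.parameters() continues to reference the right tensor. Assigning a fresh torch.nn.Parameter, by contrast, creates a new object that such an optimizer would not know about.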