Tags: deep-learning, nlp, pytorch, torch, huggingface-transformers

Train the last n% of layers of BERT in PyTorch using the HuggingFace library (train the last 5 BertLayers out of 12)


BERT has an architecture something like: embeddings -> encoder (12 BertLayers) -> pooler. I want to train the last 40% of the layers of the BERT model. I can freeze all the layers as follows:

# freeze all parameters
from transformers import AutoModel

bert = AutoModel.from_pretrained('bert-base-uncased')
for param in bert.parameters():
    param.requires_grad = False

But I want to train the last 40% of the layers. When I do len(list(bert.parameters())), it gives me 199, so let us say 79 parameters is roughly 40%. Can I do something like:

for param in list(bert.parameters())[-79:]:  # 199 params in total; 79 is ~40%
    param.requires_grad = False

I think this will freeze the first 60% of the layers.

Also, can someone tell me which layers it will freeze, in terms of the architecture?


Solution

  • You are probably looking for named_parameters.

    for name, param in bert.named_parameters():                                            
        print(name)
    

    Output:

    embeddings.word_embeddings.weight
    embeddings.position_embeddings.weight
    embeddings.token_type_embeddings.weight
    embeddings.LayerNorm.weight
    embeddings.LayerNorm.bias
    encoder.layer.0.attention.self.query.weight
    encoder.layer.0.attention.self.query.bias
    encoder.layer.0.attention.self.key.weight
    ...
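
    As a side note (a sketch, not part of the original answer), grouping those names shows how the 199 parameters map onto the architecture: 5 for the embeddings, 16 per BertLayer and 2 for the pooler (5 + 12*16 + 2 = 199, matching the count from the question). A cut-off of 79 parameters therefore does not fall on a layer boundary; it splits one BertLayer in the middle:

    from collections import Counter

    # group parameter names by top-level block:
    # 'embeddings', 'encoder.layer.0' ... 'encoder.layer.11', 'pooler'
    groups = Counter(
        '.'.join(name.split('.')[:3]) if name.startswith('encoder.layer.')
        else name.split('.')[0]
        for name, _ in bert.named_parameters()
    )
    print(groups)
    # embeddings: 5, encoder.layer.0 ... encoder.layer.11: 16 each, pooler: 2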
    

    named_parameters will also show you that you have not frozen the first 60% but the last 40%:

    for name, param in bert.named_parameters():
        if param.requires_grad == True:
            print(name) 
    

    Output:

    embeddings.word_embeddings.weight
    embeddings.position_embeddings.weight
    embeddings.token_type_embeddings.weight
    embeddings.LayerNorm.weight
    embeddings.LayerNorm.bias
    encoder.layer.0.attention.self.query.weight
    encoder.layer.0.attention.self.query.bias
    encoder.layer.0.attention.self.key.weight
    encoder.layer.0.attention.self.key.bias
    encoder.layer.0.attention.self.value.weight
    ...
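
    A quick count (again just a sketch, assuming only the [-79:] snippet from the question was applied to a freshly loaded model) confirms it: the first 60% of the parameters is still trainable:

    # count trainable vs. frozen parameters after the [-79:] snippet
    n_trainable = sum(1 for p in bert.parameters() if p.requires_grad)
    n_frozen = sum(1 for p in bert.parameters() if not p.requires_grad)
    print(n_trainable, n_frozen)  # 120 79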
    

    You can freeze the first 60% with:

    for name, param in list(bert.named_parameters())[:-79]: 
        print('I will be frozen: {}'.format(name)) 
        param.requires_grad = False
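
    If you would rather have the split follow whole layers of the architecture (for example the last 5 BertLayers mentioned in the title), you can freeze by parameter name instead of by parameter count. A minimal sketch, assuming the standard bert-base-uncased naming (encoder.layer.0 ... encoder.layer.11 plus pooler):

    from transformers import AutoModel

    bert = AutoModel.from_pretrained('bert-base-uncased')

    # train only encoder layers 7-11 and the pooler; freeze everything else
    trainable_prefixes = tuple(f'encoder.layer.{i}.' for i in range(7, 12)) + ('pooler.',)

    for name, param in bert.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)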