Tags: python, nlp, pytorch, huggingface-transformers

Dropping layers in Transformer models (PyTorch / HuggingFace)


I came across this interesting paper on layer dropping in Transformer models and I am trying to implement it. However, I am wondering what would be good practice for performing "layer dropping".

I have a couple of ideas but no idea what would be the cleanest/safest way to go here:

  • masking the unwanted layers (some sort of pruning)
  • copying the wanted layers into a new model (rough sketch below)
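
For the second option, this is roughly what I have in mind (just a sketch on a plain bert-base-uncased checkpoint, with keep_layers as an arbitrary choice):

    import torch.nn as nn
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")
    keep_layers = [1, 3, 5, 7, 9, 11]  # indices of the layers to retain

    # Rebuild the encoder's ModuleList with only the selected layers
    model.encoder.layer = nn.ModuleList(
        layer for i, layer in enumerate(model.encoder.layer) if i in keep_layers
    )
    model.config.num_hidden_layers = len(keep_layers)  # keep the config consistent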

If anyone has already done this before or has suggestions, I'm all ears!

Cheers


Solution

  • I think one of the safest ways would be simply to skip the given layers in the forward pass.

    For example, suppose you are using BERT and that you have added the following entry to the config:

    config.active_layers = [False, True] * 6  # using a 12-layer model
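
    Hugging Face configs accept arbitrary extra attributes, so one way to set this up (just a sketch, using bert-base-uncased as an example checkpoint) would be:

    from transformers import BertConfig

    config = BertConfig.from_pretrained("bert-base-uncased")
    config.active_layers = [False, True] * 6  # keep every second one of the 12 layers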
    

    Then you could modify the BertEncoder class as follows:

    import torch
    import torch.utils.checkpoint
    import torch.nn as nn
    from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
    from transformers.models.bert.modeling_bert import BertLayer
    from transformers.utils import logging

    logger = logging.get_logger(__name__)


    class BertEncoder(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
            self.gradient_checkpointing = False
    
        def forward(
            self,
            hidden_states,
            attention_mask=None,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        ):
            all_hidden_states = () if output_hidden_states else None
            all_self_attentions = () if output_attentions else None
            all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    
            next_decoder_cache = () if use_cache else None
            for i, layer_module in enumerate(self.layer):
                
                ########### MAGIC HERE #############
                # Skip this layer entirely when it is not marked active in the config
                if not self.config.active_layers[i]:
                    continue
                
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
    
                layer_head_mask = head_mask[i] if head_mask is not None else None
                past_key_value = past_key_values[i] if past_key_values is not None else None
    
                if self.gradient_checkpointing and self.training:
    
                    if use_cache:
                        logger.warning(
                            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                        )
                        use_cache = False
    
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, past_key_value, output_attentions)
    
                        return custom_forward
    
                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(layer_module),
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                    )
                else:
                    layer_outputs = layer_module(
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        past_key_value,
                        output_attentions,
                    )
    
                hidden_states = layer_outputs[0]
                if use_cache:
                    next_decoder_cache += (layer_outputs[-1],)
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)
                    if self.config.add_cross_attention:
                        all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
    
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
    
            if not return_dict:
                return tuple(
                    v
                    for v in [
                        hidden_states,
                        next_decoder_cache,
                        all_hidden_states,
                        all_self_attentions,
                        all_cross_attentions,
                    ]
                    if v is not None
                )
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=next_decoder_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
                cross_attentions=all_cross_attentions,
            )
    

    At the moment you may need to write your own BERT class that uses this modified encoder. However, you should still be able to load the weights from the pre-trained models provided by Hugging Face.
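
    For instance, a minimal sketch of such a class might look like the following, assuming the modified BertEncoder above has been renamed ActiveLayerBertEncoder so it does not clash with the stock class (both that name and BertModelWithLayerDrop are placeholders, not classes provided by the library):

    from transformers import BertConfig, BertModel

    class BertModelWithLayerDrop(BertModel):
        def __init__(self, config, add_pooling_layer=True):
            super().__init__(config, add_pooling_layer=add_pooling_layer)
            # Swap in the encoder that honours config.active_layers. The parameter
            # names (encoder.layer.N. ...) are unchanged, so the pretrained
            # checkpoint weights still map onto it.
            self.encoder = ActiveLayerBertEncoder(config)  # the modified BertEncoder above

    config = BertConfig.from_pretrained("bert-base-uncased")
    config.active_layers = [False, True] * 6
    model = BertModelWithLayerDrop.from_pretrained("bert-base-uncased", config=config)

    Since from_pretrained loads weights by parameter name, even the skipped layers still receive their pretrained weights; they are simply never used in the forward pass.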

    BertEncoder code taken from the Hugging Face transformers source (modeling_bert.py)