python, machine-learning, deep-learning, nlp, huggingface-transformers

How to calculate the weighted sum of last 4 hidden layers using Roberta?


The table from this paper explains various approaches to obtaining the embedding; I think these approaches are applicable to RoBERTa too:

[table from the paper comparing feature-based embedding approaches, e.g. embeddings only, the last hidden layer, and the sum/concat/weighted sum of the last four hidden layers]

I'm trying to calculate the weighted sum of the last 4 hidden layers of RoBERTa to obtain token embeddings, but I don't know if this is the correct way to do it. This is the code I have tried:

from transformers import RobertaTokenizer, RobertaModel
import torch

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
caption = ['this is a yellow bird', 'example caption']

tokens = tokenizer(caption, return_tensors='pt', padding=True)

input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']

output = model(input_ids, attention_mask, output_hidden_states=True)

states = output.hidden_states
token_emb = torch.stack([states[i] for i in [-4, -3, -2, -1]]).sum(0).squeeze()

Solution

  • First, let's do some digging in the OG BERT code: https://github.com/google-research/bert

    If we do a quick search for "sum" in the GitHub repo, we find this: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/modeling.py#L814

      # The Transformer performs sum residuals on all layers so the input needs
      # to be the same as the hidden size.
      if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                         (input_width, hidden_size))
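
    In other words, every layer's input/output has to have the same width as the hidden size, which is exactly what makes summing layer outputs elementwise possible. The same holds for roberta-base in HuggingFace; a quick check, using the model loaded in your snippet:

    assert model.embeddings.word_embeddings.embedding_dim == model.config.hidden_size  # 768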
    

    Then a quick search on Stack Overflow reveals "How to get intermediate layers' output of pre-trained BERT model in HuggingFace Transformers library?"


    Now, let's validate whether your code logic works by working backwards a little:

    from transformers import RobertaTokenizer, RobertaModel
    import torch
    
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaModel.from_pretrained('roberta-base')
    caption = ['this is a yellow bird', 'example caption']
    
    tokens = tokenizer(caption, return_tensors='pt', padding=True)
    
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']
    
    output = model(input_ids, attention_mask, output_hidden_states=True)
    

    Then:

    >>> len(output.hidden_states)
    

    Out:

    13
    

    Why 13?

    1 embedding layer output + 12 encoder (hidden) layer outputs; the pooler output is not part of hidden_states (it is returned separately as output.pooler_output)

    import torchinfo
    torchinfo.summary(model)
    

    [out]:

    ================================================================================
    Layer (type:depth-idx)                                  Param #
    ================================================================================
    RobertaModel                                            --
    ├─RobertaEmbeddings: 1-1                                --
    │    └─Embedding: 2-1                                   38,603,520
    │    └─Embedding: 2-2                                   394,752
    │    └─Embedding: 2-3                                   768
    │    └─LayerNorm: 2-4                                   1,536
    │    └─Dropout: 2-5                                     --
    ├─RobertaEncoder: 1-2                                   --
    │    └─ModuleList: 2-6                                  --
    │    │    └─RobertaLayer: 3-1                           7,087,872
    │    │    └─RobertaLayer: 3-2                           7,087,872
    │    │    └─RobertaLayer: 3-3                           7,087,872
    │    │    └─RobertaLayer: 3-4                           7,087,872
    │    │    └─RobertaLayer: 3-5                           7,087,872
    │    │    └─RobertaLayer: 3-6                           7,087,872
    │    │    └─RobertaLayer: 3-7                           7,087,872
    │    │    └─RobertaLayer: 3-8                           7,087,872
    │    │    └─RobertaLayer: 3-9                           7,087,872
    │    │    └─RobertaLayer: 3-10                          7,087,872
    │    │    └─RobertaLayer: 3-11                          7,087,872
    │    │    └─RobertaLayer: 3-12                          7,087,872
    ├─RobertaPooler: 1-3                                    --
    │    └─Linear: 2-7                                      590,592
    │    └─Tanh: 2-8                                        --
    ================================================================================
    Total params: 124,645,632
    Trainable params: 124,645,632
    Non-trainable params: 0
    ================================================================================
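
    Counting from the summary above: one RobertaEmbeddings output plus twelve RobertaLayer outputs gives the 13 entries (the RobertaPooler output is returned separately). A one-line sanity check, using num_hidden_layers from the model config:

    assert len(output.hidden_states) == model.config.num_hidden_layers + 1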
    

    To validate that the last layer output is the last layer in the hidden_states:

    # output[0] is output.last_hidden_state
    assert torch.equal(output[0], output.hidden_states[-1])
    

    Let's check that every layer's output has the same shape:

    first_hidden_shape = output.hidden_states[0].shape
    
    for x in output.hidden_states:
      assert x.shape == first_hidden_shape
    

    Checks out!

    >>> first_hidden_shape
    

    [out]:

    torch.Size([2, 7, 768])
    

    Why [2, 7, 768]?

    It's (batch_size, sequence_length, hidden_size); a quick check follows the list below.

    • 2 = batch size, because caption contains two sentences
    • 7 = sequence length of the longest example, i.e. the 5 tokens of "this is a yellow bird" plus the <s> and </s> special tokens (check len(input_ids[0]))
    • 768 = hidden size, which is fixed for every hidden layer's output
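
    To tie those three numbers back to the inputs (model.config.hidden_size is 768 for roberta-base):

    assert first_hidden_shape == (len(caption), input_ids.shape[1], model.config.hidden_size)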

    Bíddu aðeins! (Wait a minute!) Does that mean I get sequence_length * 768 outputs for each sentence? And if my batches have different lengths, the output sizes differ?

    Yes, that is correct! To get some sense of "equality" across all inputs, it's a good idea to pad/truncate all inputs to a fixed length if you're still going to use the feature-based BERT approaches, as sketched below.
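
    A minimal sketch of fixing the length at tokenization time; the variable names and the max_length of 16 are arbitrary choices for illustration:

    fixed_tokens = tokenizer(caption, return_tensors='pt',
                             padding='max_length', truncation=True, max_length=16)
    fixed_output = model(fixed_tokens['input_ids'], fixed_tokens['attention_mask'],
                         output_hidden_states=True)
    # every hidden state now has shape (batch_size, 16, 768), regardless of sentence length
    assert fixed_output.hidden_states[0].shape == (2, 16, 768)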

    Soooo, is my torch.stack approach right?

    Yes, it seems so. Keep in mind that output.hidden_states[0] is the embedding output and output.hidden_states[-1] is the output of the final encoder layer (the pooler output is not part of hidden_states at all), so the sum of the last four encoder layers is:

    torch.stack(output.hidden_states[-4:]).sum(0)

    which is exactly what your [-4, -3, -2, -1] indexing does. (If you instead wanted the four layers before the final one, that would be output.hidden_states[-5:-1].)
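
    And since the table entry is a weighted sum rather than a plain sum, here is a minimal sketch of weighting the four layers before summing; the weight values and variable names are arbitrary placeholders (in practice the weights could also be learnable parameters):

    # shape (4, batch_size, sequence_length, hidden_size)
    stacked = torch.stack(output.hidden_states[-4:])
    # one weight per layer; these values are placeholders, not from the paper
    weights = torch.tensor([0.1, 0.2, 0.3, 0.4])
    # broadcast the weights over the other dimensions, then sum out the layer dimension
    weighted_token_emb = (weights[:, None, None, None] * stacked).sum(0)  # (batch_size, sequence_length, hidden_size)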
    

    Minor nitpick: you can slice output.hidden_states directly because it's a tuple. Also, you don't need to .squeeze() the stacked-and-summed output, because none of its dimensions has size 1 (squeeze would only matter for a batch of size 1).
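
    A quick check of both points on the tensors we already have:

    assert isinstance(output.hidden_states, tuple)
    summed = torch.stack(output.hidden_states[-4:]).sum(0)
    # squeeze() is a no-op here because no dimension has size 1
    assert summed.squeeze().shape == summed.shape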

    Note that torch.stack defaults to dim=0, i.e. it stacks the four layers along a new leading dimension (before batch size and sequence length); summing over that dimension gives the same result as stacking along an inner dimension and summing over that one, which is why the result is the same even when you don't state the dimension explicitly.

    To be a little more explicit:

    # Stack the four layers along a new dimension at index 2
    # (after batch size and sequence length), then sum over that same dimension.
    torch.stack(output.hidden_states[-4:], dim=2).sum(2)
    

    But in practice, you can do this to comfort yourself:

    assert torch.allclose(
        torch.stack(output.hidden_states[-4:]).sum(0),
        torch.stack(output.hidden_states[-4:], dim=2).sum(2)
    )
    

    Interesting, what about "concat last four hidden"?

    In the case of concat, you'll need to be explicit about the dimension:

    >>> torch.cat(output.hidden_states[-4:], dim=2).shape
    

    [out]:

    torch.Size([2, 7, 3072])
    

    but note that you still end up with sequence_length * hidden_size * 4 values per sentence, which makes batches with unequal lengths a pain.

    Since you've covered almost everything in the table, what about the "embeddings" output?

    This is the interesting part: model(input_ids) doesn't return it as a separate field, but it is the first entry of the hidden states, output.hidden_states[0], and you can also compute it straight from the embedding module:

    model.embeddings(input_ids)
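
    As a sanity sketch, the two should match, assuming the model is in eval mode (from_pretrained loads it that way) so that dropout in the embedding module is a no-op:

    with torch.no_grad():
        emb = model.embeddings(input_ids)
    assert torch.allclose(emb, output.hidden_states[0])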
    

    Finally, why didn't you just answer "yes, you are right"?

    If I did, would that convince you more than you proving the above for yourself?