Tags: python, deep-learning, nlp, feature-extraction, bert-language-model

Extract and concatenate the last 4 hidden states from a BERT model for each input


I want to extract and concatenate the last 4 hidden states from BERT for each input sentence and save them. I use this code, but I get only the last hidden state:

import torch
import torch.nn as nn
from transformers import AutoModel

class MixModel(nn.Module):
    def __init__(self, pre_trained='bert-base-uncased'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pre_trained)
        self.hidden_size = self.bert.config.hidden_size

    def forward(self, inputs, mask, labels):
        cls_hs = self.bert(input_ids=inputs, attention_mask=mask,
                           return_dict=False, output_hidden_states=True)
        print(cls_hs)

        encoded_layers = cls_hs[0]  # this is only the last hidden state
        print(encoded_layers.size())
        # output is [1, 64, 768]

        return encoded_layers

Batch size is 1 and the padded sequence length is 64.

How do I extract the last four hidden states?


Solution

  • Call the model with return_dict=True (or simply drop return_dict=False) so the output can be indexed by name, then grab the last 4 hidden states. Since you passed output_hidden_states=True, the output's hidden_states field is a tuple containing the embedding output plus one tensor per layer; slicing the last 4 gives a tuple of 4 tensors, each of shape (batch_size, seq_len, hidden_size):

    encoded_layers = cls_hs['hidden_states'][-4:]


  • Concatenate them (here over the last dimension) into a single tensor of shape (batch_size, seq_len, 4 * hidden_size):

    concatenated = torch.cat(encoded_layers, dim=-1)
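To see the shapes involved without downloading a model, here is a minimal sketch where dummy random tensors stand in for the 13 hidden-state tensors a 12-layer BERT returns (embedding output plus one per layer); the slicing and concatenation are exactly as above:

```python
import torch

# Dummy stand-ins for the hidden_states tuple of a 12-layer BERT:
# 13 tensors (embeddings + 12 layers), each (batch_size, seq_len, hidden_size).
batch_size, seq_len, hidden_size = 1, 64, 768
hidden_states = tuple(torch.randn(batch_size, seq_len, hidden_size)
                      for _ in range(13))

# Take the last four layers' outputs.
last_four = hidden_states[-4:]

# Concatenate along the feature dimension -> (1, 64, 4 * 768) = (1, 64, 3072).
concatenated = torch.cat(last_four, dim=-1)
print(concatenated.shape)  # torch.Size([1, 64, 3072])
```

With your batch size of 1 and padding size of 64, each sentence ends up as a single (1, 64, 3072) tensor, which you can then save (e.g. with torch.save) or feed into a downstream head.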