I want to extract and concatenate the last 4 hidden states from BERT for each input sentence and save them. I use this code, but I only get the last hidden state:
class MixModel(nn.Module):
    """Wrap a pretrained transformer encoder and expose its token-level hidden states.

    With the default settings the module returns the encoder's last hidden
    state, of shape (batch_size, seq_len, hidden_size) -- e.g. [1, 64, 768]
    for batch size 1 and padding length 64. Setting ``num_last_layers`` to
    ``k > 1`` instead returns the last ``k`` hidden states concatenated over
    the feature dimension, giving (batch_size, seq_len, k * hidden_size).
    """

    def __init__(self, pre_trained='distilbert-base-uncased', num_last_layers=1):
        """Load the pretrained encoder.

        Args:
            pre_trained: Checkpoint name or path passed to
                ``AutoModel.from_pretrained``. The default is the checkpoint
                that was previously hard-coded, so default construction loads
                the same weights as before; the argument is now honored
                instead of being ignored.
            num_last_layers: How many of the final hidden states ``forward``
                concatenates. ``1`` keeps the original behavior (last hidden
                state only); use ``4`` to get the last four.

        Raises:
            ValueError: If ``num_last_layers`` is smaller than 1.
        """
        super().__init__()
        if num_last_layers < 1:
            raise ValueError("num_last_layers must be >= 1")
        self.bert = AutoModel.from_pretrained(pre_trained)
        self.hidden_size = self.bert.config.hidden_size
        self.num_last_layers = num_last_layers

    def forward(self, inputs, mask, labels=None):
        """Encode a batch and return its (optionally concatenated) hidden states.

        Args:
            inputs: Token ids, forwarded to the encoder as ``input_ids``.
            mask: Attention mask, forwarded as ``attention_mask``.
            labels: Unused; accepted only so existing callers keep working.

        Returns:
            The last hidden state when ``num_last_layers == 1``; otherwise the
            last ``num_last_layers`` hidden states concatenated over the last
            dimension.
        """
        # Function-scope import: the class itself only needs `nn`, so this keeps
        # the block self-contained without touching the file's import section.
        import torch

        # return_dict=True lets the fields be read by name rather than by
        # position; tuple positions differ between architectures (some models
        # insert a pooled output before the hidden states).
        outputs = self.bert(
            input_ids=inputs,
            attention_mask=mask,
            return_dict=True,
            output_hidden_states=True,
        )
        if self.num_last_layers == 1:
            return outputs.last_hidden_state

        # hidden_states is a tuple of per-layer tensors; the slice keeps the
        # final `num_last_layers` entries (or all of them if fewer exist).
        selected = outputs.hidden_states[-self.num_last_layers:]
        return torch.cat(selected, dim=-1)
The batch size is 1 and the padding length is 64.
How do I extract the last four hidden states?
Grab the last 4 hidden states; this gives a tuple of 4 tensors, each of shape (batch_size, seq_len, hidden_size):
# Keep the final four entries of the model's hidden-state tuple.
# NOTE(review): indexing by the string key 'hidden_states' assumes cls_hs is a
# dict-like model output (model called with return_dict=True, not
# return_dict=False, which presumably yields a plain tuple) -- confirm.
encoded_layers = cls_hs['hidden_states'][-4:]
Then concatenate them (here over the last dimension) into a single tensor of shape (batch_size, seq_len, 4 * hidden_size):
# Join the tensors in encoded_layers along the last dimension (dim=-1).
concatenated = torch.cat(encoded_layers, -1)