I am trying to do a classification task. I took the last 4 hidden layers from BERT and concatenated them:
out = model(...)
out=torch.cat([out['hidden_states'][-i] for i in range(1,5)],dim=-1)
Now the shape is (12, 200, 768*4), i.e. (batch, max_length, concatenated hidden size).
But to classify each sequence I need a two-dimensional tensor (one vector per sequence) before the fully connected layer. One way is to average over the sequence dimension, e.g. torch.mean(out, dim=1), which gives an output of shape (12, 768*4).
But I am confused about what the original BERT approach is.
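To make the shapes concrete, here is a minimal sketch that uses random tensors as a hypothetical stand-in for out['hidden_states'] (assumption: batch size 12, max_length 200, hidden size 768, and 13 hidden states, as a 12-layer BERT returns with output_hidden_states=True). It shows the concatenation and the plain average over the sequence dimension described above:

```python
import torch

# Hypothetical stand-in for out['hidden_states']: embeddings + 12 layers,
# each of shape (batch, max_length, hidden_size)
batch, max_length, hidden = 12, 200, 768
hidden_states = tuple(torch.randn(batch, max_length, hidden) for _ in range(13))

# Concatenate the last 4 layers along the feature dimension
out = torch.cat([hidden_states[-i] for i in range(1, 5)], dim=-1)
print(out.shape)  # torch.Size([12, 200, 3072])

# Plain (unmasked) average over the sequence dimension
pooled = torch.mean(out, dim=1)
print(pooled.shape)  # torch.Size([12, 3072])
```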
There is no "original" BERT approach for classification with concatenated hidden layers (for sequence classification, the original BERT paper feeds the final-layer hidden state of the [CLS] token into a classification layer). You have several options to proceed; below I comment on your approach and suggest an alternative.
Preliminary:
import torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertModel
t = BertTokenizerFast.from_pretrained("bert-base-cased")
m = BertModel.from_pretrained("bert-base-cased")
#randomly initialized classification layer with 5 labels (for demonstration only)
fc = nn.Linear(768, 5)
s = ["This is a random sentence", "This is another random sentence with more words"]
i = t(s, padding=True, return_tensors="pt")
with torch.no_grad():
    o = m(**i, output_hidden_states=True)
print(i)
First, look at your input:
#print(i)
{'input_ids':
tensor([[ 101, 1188, 1110, 170, 7091, 5650, 102, 0, 0, 0],
[ 101, 1188, 1110, 1330, 7091, 5650, 1114, 1167, 1734, 102]]),
'token_type_ids':
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'attention_mask':
tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}
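The attention mask marks real tokens with 1 and padding positions with 0, so the real length of each sentence is just the row sum. A small sketch using the mask printed above:

```python
import torch

# attention_mask copied from the tokenizer output above
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

# Number of real tokens per sentence ([CLS] and [SEP] included): 7 and 10
real_lengths = attention_mask.sum(dim=1)
print(real_lengths)

# Plain torch.mean over dim=1 divides both rows by the padded length (10)
padded_length = attention_mask.shape[1]
print(padded_length)
```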
What you should notice here is that the shorter sentence gets padded. That matters because simply mean-pooling with torch.mean will give different embeddings for the same sentence depending on how many padding tokens it has. The model will learn to handle that to some extent after sufficient training, but you should rather use a mean function that excludes the padding tokens right away:
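def mean_pooling(model_output, attention_mask):
    #broadcast the mask to the hidden size so padding positions are zeroed out
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    #sum over the real tokens and divide by their number (clamp avoids division by zero)
    return torch.sum(model_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)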
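To see why the masking matters, here is a small self-contained check with toy tensors (assumption: random values in place of real BERT hidden states). The same "sentence" is pooled once without padding and once with three padding positions appended; the masked mean_pooling gives the same result both times, while plain torch.mean does not:

```python
import torch

def mean_pooling(model_output, attention_mask):
    # Broadcast the mask to the hidden size so padding positions are zeroed out
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    # Sum over real tokens and divide by the number of real tokens
    return torch.sum(model_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

torch.manual_seed(0)
hidden = 768
real = torch.randn(1, 7, hidden)            # 7 real token vectors
padding = torch.randn(1, 3, hidden)         # 3 padding positions (arbitrary values)
padded = torch.cat([real, padding], dim=1)  # shape (1, 10, 768)

mask_unpadded = torch.ones(1, 7, dtype=torch.long)
mask_padded = torch.tensor([[1] * 7 + [0] * 3])

masked_a = mean_pooling(real, mask_unpadded)
masked_b = mean_pooling(padded, mask_padded)
plain_a = torch.mean(real, dim=1)
plain_b = torch.mean(padded, dim=1)

print(torch.allclose(masked_a, masked_b, atol=1e-6))  # True
print(torch.allclose(plain_a, plain_b, atol=1e-6))    # False
```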
o_mean = [mean_pooling(o.hidden_states[-x],i.attention_mask) for x in range(1,5)]
#we want a tensor and not a list
o_mean = torch.stack(o_mean, dim=1)
#we want only one tensor per sequence
o_mean = torch.mean(o_mean,dim=1)
print(o_mean.shape)
with torch.no_grad():
    print(fc(o_mean))
Output:
torch.Size([2, 768])
tensor([[ 0.0677, -0.0261, -0.3602, 0.4221, 0.2251],
[-0.0328, -0.0161, -0.5209, 0.5825, 0.2405]])
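If you want to keep the concatenation from your question instead of averaging the four layers, the change is to concatenate the pooled vectors along the feature dimension and size the linear layer to 768 * 4 inputs. A minimal sketch, using random tensors as a hypothetical stand-in for o.hidden_states so it runs without downloading the model:

```python
import torch
import torch.nn as nn

def mean_pooling(model_output, attention_mask):
    # Masked mean over the sequence dimension (padding tokens are ignored)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    return torch.sum(model_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Hypothetical stand-in for o.hidden_states (13 tensors of shape (batch, seq_len, 768))
batch, seq_len, hidden = 2, 10, 768
hidden_states = tuple(torch.randn(batch, seq_len, hidden) for _ in range(13))
attention_mask = torch.tensor([[1] * 7 + [0] * 3,
                               [1] * 10])

# Pool each of the last 4 layers, then concatenate instead of averaging
o_cat = torch.cat([mean_pooling(hidden_states[-x], attention_mask) for x in range(1, 5)], dim=-1)
print(o_cat.shape)  # torch.Size([2, 3072])

# The classifier input size must now be 4 * 768
fc_cat = nn.Linear(768 * 4, 5)
with torch.no_grad():
    logits = fc_cat(o_cat)
print(logits.shape)  # torch.Size([2, 5])
```

Concatenation gives the classifier four times as many input features (and first-layer parameters), while averaging keeps the size at 768; which works better is an empirical question for your data.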
These operations add some overhead, and people often use an approach called CLS pooling as a cheaper alternative that often performs comparably:
#We only use the [CLS] token (i.e. the first token of each sequence, id 101)
o_cls = [o.hidden_states[-x][:, 0] for x in range(1,5)]
#we want a tensor and not a list
o_cls = torch.stack(o_cls, dim=1)
#we want only one tensor per sequence
o_cls = torch.mean(o_cls,dim=1)
print(o_cls.shape)
with torch.no_grad():
    print(fc(o_cls))
Output:
torch.Size([2, 768])
tensor([[-0.3731, 0.0473, -0.4472, 0.3804, 0.4057],
[-0.3468, 0.0685, -0.5885, 0.4994, 0.4182]])
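For training, you would typically wrap the pooling and the linear layer into one module so BERT and the classifier can be fine-tuned together. Below is a minimal sketch of such a head; the class name PooledHead and its pooling argument are hypothetical. It expects the tuple of hidden states (as BertModel returns with output_hidden_states=True) and the tokenizer's attention mask; random tensors stand in for both here:

```python
import torch
import torch.nn as nn

class PooledHead(nn.Module):
    """Hypothetical classification head: pools each of the last 4 hidden layers
    ([CLS] or masked mean), averages them, then applies a linear layer."""

    def __init__(self, hidden_size=768, num_labels=5, pooling="cls"):
        super().__init__()
        if pooling not in ("cls", "mean"):
            raise ValueError("pooling must be 'cls' or 'mean'")
        self.pooling = pooling
        self.fc = nn.Linear(hidden_size, num_labels)

    def pool(self, layer, attention_mask):
        if self.pooling == "cls":
            # First token of every sequence ([CLS])
            return layer[:, 0]
        # Masked mean over real tokens only
        mask = attention_mask.unsqueeze(-1).to(layer.dtype)
        return (layer * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

    def forward(self, hidden_states, attention_mask):
        pooled = torch.stack([self.pool(hidden_states[-x], attention_mask) for x in range(1, 5)], dim=1)
        # Average the 4 layers -> one vector per sequence
        return self.fc(pooled.mean(dim=1))

# Hypothetical stand-in for o.hidden_states and i.attention_mask
hidden_states = tuple(torch.randn(2, 10, 768) for _ in range(13))
attention_mask = torch.tensor([[1] * 7 + [0] * 3, [1] * 10])

for pooling in ("cls", "mean"):
    head = PooledHead(pooling=pooling)
    logits = head(hidden_states, attention_mask)
    print(pooling, logits.shape)  # torch.Size([2, 5]) for both
```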