Consider a batch of sentences with different lengths.
When using the BertTokenizer, I apply padding so that all the sequences have the same length, and I end up with a nice tensor of shape (bs, max_seq_len).
After applying the BertModel, I get a last hidden state of shape (bs, max_seq_len, hidden_sz).
My goal is to get the mean-pooled sentence embedding for each sentence (resulting in something with shape (bs, hidden_sz)), but excluding the embeddings for the PAD tokens when taking the mean.
Is there a way to do this efficiently without looping over each sequence in the batch?
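For reference, here is roughly what the setup looks like (a minimal sketch; the model name and sentences are just placeholders):

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

sentences = ["A short sentence.", "A somewhat longer sentence with many more tokens."]
batch = tokenizer(sentences, padding=True, return_tensors="pt")  # input_ids: (bs, max_seq_len)

hidden_state = model(**batch).last_hidden_state                  # (bs, max_seq_len, hidden_sz)
attention_mask = batch["attention_mask"]                         # 1 for real tokens, 0 for PAD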
Thanks!
You can fill the PAD positions with NaN and then use torch.nanmean. You can then change the values back to something less likely to cause gradient issues down the line:
mean_pooled = torch.nanmean(hidden_state, dim=1)        # NaN (PAD) positions are ignored in the mean
hidden_state = torch.nan_to_num(hidden_state, nan=0.0)  # put zeros back afterwards
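A minimal sketch of that approach, assuming hidden_state and attention_mask as in the setup above (the attention mask from the tokenizer is what tells you which positions are PAD):

# Put NaN at the PAD positions so that nanmean skips them
pad_positions = attention_mask.unsqueeze(-1) == 0                     # (bs, max_seq_len, 1)
hidden_state = hidden_state.masked_fill(pad_positions, float("nan"))
mean_pooled = torch.nanmean(hidden_state, dim=1)                      # (bs, hidden_sz)
hidden_state = torch.nan_to_num(hidden_state, nan=0.0)                # restore zeros at the PAD positions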
Alternatively, take the sum over the sequence dimension and divide by the number of non-zero elements (assuming the PAD positions are zero):
mean_pooled = torch.sum(hidden_state, dim=1) / torch.where(hidden_state != 0, 1, 0).sum(dim=1)
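Note that BERT's hidden states at the PAD positions are generally not exactly zero, so for this to work you would first zero them out with the attention mask. A sketch under that assumption, again using attention_mask from the setup above:

# Zero out the PAD positions so that "non-zero" really means "non-PAD"
mask = attention_mask.unsqueeze(-1).to(hidden_state.dtype)  # (bs, max_seq_len, 1)
masked = hidden_state * mask
mean_pooled = masked.sum(dim=1) / (masked != 0).sum(dim=1)  # (bs, hidden_sz)

A common variant divides by mask.sum(dim=1) (the number of real tokens per sentence) instead, which avoids miscounting in the unlikely case that a real activation is exactly 0.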