pytorch · nlp · huggingface-transformers

How to efficiently mean-pool BERT embeddings while excluding padding?


Consider a batch of sentences with different lengths.

When using the BertTokenizer, I apply padding so that all the sequences have the same length, which gives a tensor of input IDs of shape (bs, max_seq_len).

After applying the BertModel, I get a last hidden state of shape (bs, max_seq_len, hidden_sz).
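
For concreteness, here is a minimal sketch of the setup I mean (the checkpoint name and example sentences are just placeholders):

    from transformers import BertTokenizer, BertModel
    import torch

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    sentences = ["A short sentence.", "A somewhat longer sentence with more tokens."]
    encoded = tokenizer(sentences, padding=True, return_tensors="pt")  # input_ids: (bs, max_seq_len)

    with torch.no_grad():
        hidden_state = model(**encoded).last_hidden_state  # (bs, max_seq_len, hidden_sz)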

My goal is to get the mean-pooled sentence embedding for each sentence (resulting in something with shape (bs, hidden_sz)), but excluding the embeddings for the PAD tokens when taking the mean.

Is there a way to do this efficiently without looping over each sequence in the batch?

Thanks!


Solution

  • You can set the hidden states at the PAD positions to NaN and then use torch.nanmean over the sequence dimension. Afterwards, you can change those values back to something less likely to cause gradient issues down the line (a sketch of how to place the NaNs using the attention mask is at the end of this answer).

    mean_pooled = torch.nanmean(hidden_state, dim=1)
    hidden_state = torch.nan_to_num(hidden_state, nan=0.0)
    

    Alternatively, take the sum over the sequence dimension and divide by the number of non-zero elements (assuming the PAD positions are all zeros).

    mean_pooled = torch.sum(hidden_state, dim=1) / torch.where(hidden_state != 0, 1, 0).sum(dim=1)
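
    For reference, here is the NaN route end to end as a sketch. It assumes you kept the attention_mask returned by the tokenizer (the mask is not part of the answer above; it is just the usual way to know where the PAD tokens sit):

    import torch

    # hidden_state: (bs, max_seq_len, hidden_sz); attention_mask: (bs, max_seq_len)
    pad_positions = attention_mask.unsqueeze(-1) == 0  # True at PAD positions

    # Put NaN at the PAD positions so nanmean ignores them
    hidden_state = hidden_state.masked_fill(pad_positions, float("nan"))
    mean_pooled = torch.nanmean(hidden_state, dim=1)  # (bs, hidden_sz)

    # Swap the NaNs back out so later operations don't propagate them
    hidden_state = torch.nan_to_num(hidden_state, nan=0.0)

    If you would rather avoid NaN entirely, the same attention_mask can also replace the non-zero count in the second approach: zero out the PAD positions, sum over the sequence dimension, and divide by attention_mask.sum(dim=1, keepdim=True).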