
BERT sentence embedding by summing last 4 layers


I used Chris McCormick's tutorial on BERT with pytorch-pretrained-bert to get a sentence embedding as follows:

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "Here is the sentence I want embeddings for."  # placeholder input sentence
marked_text = "[CLS] " + text + " [SEP]"

tokenized_text = tokenizer.tokenize(marked_text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [1] * len(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

with torch.no_grad():
    # encoded_layers: list of 12 tensors, each of shape [batch, # tokens, 768]
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
    # Holds the list of 12 layer embeddings for each token
    # Will have the shape: [# tokens, # layers, # features]
    token_embeddings = []
    batch_i = 0  # only one sentence in the batch

    # For each token in the sentence...
    for token_i in range(len(tokenized_text)):
        # Holds 12 layers of hidden states for each token
        hidden_layers = []

        # For each of the 12 layers...
        for layer_i in range(len(encoded_layers)):
            # Look up the vector for `token_i` in `layer_i`
            vec = encoded_layers[layer_i][batch_i][token_i]
            hidden_layers.append(vec)

        token_embeddings.append(hidden_layers)
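To inspect this nested structure without downloading BERT, here is a minimal sketch that substitutes random tensors for encoded_layers. It assumes the pytorch-pretrained-bert output format (a list of 12 tensors, each of shape [batch, tokens, 768]); the 5-token length and the torch.randn values are illustrative stand-ins, not real embeddings.

```python
import torch

# Hypothetical stand-in for BERT's output: 12 layers, batch of 1, 5 tokens, 768 features
num_layers, num_tokens, hidden_size = 12, 5, 768
encoded_layers = [torch.randn(1, num_tokens, hidden_size) for _ in range(num_layers)]

batch_i = 0  # only one sentence in the batch
token_embeddings = []
for token_i in range(num_tokens):
    # Collect this token's vector from every layer
    hidden_layers = [encoded_layers[layer_i][batch_i][token_i] for layer_i in range(num_layers)]
    token_embeddings.append(hidden_layers)

# The outer list is indexed by token, the inner list by layer
print(len(token_embeddings))         # 5 (tokens)
print(len(token_embeddings[0]))      # 12 (layers)
print(token_embeddings[0][0].shape)  # torch.Size([768])
```

So iterating over token_embeddings visits tokens, not layers.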

Now, I am trying to get the final sentence embedding by summing the last 4 layers as follows:

summed_last_4_layers = [torch.sum(torch.stack(layer)[-4:], 0) for layer in token_embeddings]

But instead of getting a single torch vector of length 768 I get the following:

[tensor([-3.8930e+00, -3.2564e+00, -3.0373e-01,  2.6618e+00,  5.7803e-01,
-1.0007e+00, -2.3180e+00,  1.4215e+00,  2.6551e-01, -1.8784e+00,
-1.5268e+00,  3.6681e+00, ...., 3.9084e+00]), tensor([-2.0884e+00, -3.6244e-01,  ....2.5715e+00]), tensor([ 1.0816e+00,...-4.7801e+00]), tensor([ 1.2713e+00,.... 1.0275e+00]), tensor([-6.6105e+00,..., -2.9349e-01])]

What did I get here? How do I pool the sum of the last four layers?

Thank you!


Solution

  • Your list comprehension iterates over token_embeddings, which holds one entry per token, not one entry per layer as you probably assumed (judging from the variable name in for layer in token_embeddings). The result is therefore a list whose length equals the number of tokens: for each token, you get a 768-dimensional vector that is the sum of that token's BERT embeddings from the last 4 layers.

    It is more efficient to avoid the explicit for loops and list comprehensions:

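    # [4, 1, # tokens, 768] -> sum over layers -> [1, # tokens, 768] -> drop batch dim -> [# tokens, 768]
    summed_last_4_layers = torch.stack(encoded_layers[-4:]).sum(0).squeeze(0)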
    

    Now the variable summed_last_4_layers contains the same data, but as a single tensor of shape (sentence length) × 768.

    To get a single (i.e., pooled) vector, pool over the first (token) dimension of the tensor. Max-pooling or average-pooling probably makes more sense here than summing all the token embeddings: summed vectors of sentences with different lengths fall into different value ranges and are not really comparable.
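A minimal sketch of both points, with random tensors standing in for real BERT output (assumed format: a list of 12 tensors of shape [1, tokens, 768], as returned by pytorch-pretrained-bert's BertModel; the 7-token length is arbitrary). It checks that the vectorized sum reproduces the per-token list from the question, then pools over the token dimension with mean and max to get one 768-dimensional sentence vector.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for encoded_layers: 12 layers of shape [batch=1, tokens=7, 768]
encoded_layers = [torch.randn(1, 7, 768) for _ in range(12)]

# Loop-based version from the question: one summed vector per token
token_embeddings = [[layer[0][t] for layer in encoded_layers] for t in range(7)]
per_token = [torch.sum(torch.stack(layers)[-4:], 0) for layers in token_embeddings]

# Vectorized version: [4, 1, 7, 768] -> sum over layers -> drop batch dim -> [7, 768]
summed_last_4_layers = torch.stack(encoded_layers[-4:]).sum(0).squeeze(0)
assert torch.allclose(summed_last_4_layers, torch.stack(per_token))

# Pool over the token dimension (dim 0) to get one sentence vector
mean_pooled = summed_last_4_layers.mean(dim=0)      # shape [768]
max_pooled, _ = summed_last_4_layers.max(dim=0)     # max(dim=...) returns (values, indices)
print(mean_pooled.shape, max_pooled.shape)
```

Mean pooling divides by the number of tokens, so its scale does not grow with sentence length, which addresses the comparability issue with plain summation.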