Tags: python, pytorch, nlp, bert-language-model, embedding

Memory overload when calculating sentence embeddings with BERT


I'm trying to calculate sentence embeddings with BERT. I feed each sentence into BERT and then mean-pool the output, which I use as the sentence's embedding.
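For context, mean-pooling here means averaging the token vectors of BERT's last hidden state while ignoring padding tokens. A minimal sketch of that step (the attention-mask weighting is my assumption about the usual recipe, not code from the question):

import torch

def mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).float()     # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)  # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)        # number of real tokens per row
    return summed / counts                          # (batch, hidden)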

Problem

My code does produce sentence embeddings, but the memory cost is enormous and my machine runs out of RAM. I don't know what's wrong and I hope someone can help me.

Load BERT

import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
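(One habit worth adding here, though it is not in the original post: put the model in eval mode so dropout is disabled during inference.)

model.eval()  # inference only: disables dropout, so repeated runs give identical embeddings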

Get Embedding Function

# get the sentence embedding from BERT
def get_word_embedding(text: str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch size 1
    outputs = model(input_ids)
    # outputs[0] is the last hidden state; outputs[1] is the pooler output
    pooled = outputs[1]
    return pooled[0]
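A side note on this function: for BertModel, outputs[0] is the last hidden state and outputs[1] is the pooler output, so the code above actually returns the pooler output rather than a mean-pooled vector. It also runs the forward pass without torch.no_grad(), which makes PyTorch keep the whole autograd graph attached to every returned tensor; inside a loop, that is a classic source of memory blow-ups. A rewrite illustrating both points (my sketch, not the poster's code):

import torch

def get_word_embedding(text: str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch size 1
    with torch.no_grad():                    # don't build the autograd graph
        outputs = model(input_ids)
    last_hidden_state = outputs[0]           # (1, seq_len, hidden) token vectors
    # batch size 1 and no padding, so a plain mean over tokens is safe here
    return last_hidden_state[0].mean(dim=0)  # (hidden,) sentence vector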

Data

Each text is at most 50 words long. For every row I concatenate the entity with its text and compute an embedding of the result.
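The original post showed a screenshot of the data here. Judging from the loop below, entity_desc is a DataFrame shaped roughly like this (the column names come from the code; the rows are made-up examples):

import pandas as pd

# hypothetical stand-in for the screenshot of entity_desc
entity_desc = pd.DataFrame({
    'entity_id': [0, 1],
    'entity': ['Barack Obama', 'Python'],
    'text': ['44th president of the United States', 'high-level programming language'],
})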


Run code

entity_desc is my data. This is the step that overloads my machine every time I run it. Please help me!

I was using a machine with 80 GB of RAM on Colab.

entity_embedding = {}
for i in range(len(entity_desc)):
    entity = entity_desc['entity'][i]
    text = entity_desc['text'][i]
    entity += ' ' + text  # embed "entity + description" as one string
    entity_embedding[entity_desc['entity_id'][i]] = get_word_embedding(entity)

Solution

  • I fixed the problem. The memory overload happened because I wasn't moving the tensors to the GPU (and detaching them from the computation graph), so I made the following changes to the code.

    import torch

    # device was undefined in the original snippet; define it before moving the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # get the sentence embedding from BERT
    def get_word_embedding(text: str):
        input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch size 1
        input_ids = input_ids.to(device)

        outputs = model(input_ids)
        pooled = outputs[1]  # pooler output (outputs[0] is the last hidden state)
        # detach() drops the autograd graph so the stored tensor stays small
        return pooled[0].detach()
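A note on why this works (my reading, not part of the original answer): the call that actually stops the memory growth is .detach(), which drops the autograd graph PyTorch otherwise keeps attached to every model output; moving tensors to the GPU only relocates the memory. Wrapping the forward pass in torch.no_grad() avoids building the graph in the first place, and copying each result back to the CPU keeps the GPU from filling up when thousands of embeddings are stored:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

def get_word_embedding(text: str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(device)
    with torch.no_grad():   # no autograd graph, so no memory build-up per call
        outputs = model(input_ids)
    pooled = outputs[1]     # pooler output, as in the original code
    return pooled[0].cpu()  # store on CPU so the GPU doesn't fill up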