Tags: caching, huggingface-transformers, torch, bert-language-model, transformer-model

Why is BERT storing cache even after caching is disabled?


I am trying to extract hidden-state features from a fine-tuned BERT model, but each text entry consumes memory that is not freed after the call. I can only process 20-30 sentences with 24 GB of RAM.

import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel

data = pd.read_csv('https://docs.google.com/spreadsheets/d/' + 
                   '1cFyUJdpFC3gpQsqjNc4D8ZCxBAMd_Pcpu8SlrsjAv-Q' +
                   '/export?gid=0&format=csv',
                  )
data = data.MESSAGES


# I will be using my own fine-tuned model, but with
# bert-base-uncased, I get the same problem
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained(
    "bert-base-uncased",
    from_tf=True,
    output_hidden_states=True,
    use_cache=False)

sentences = data[0:].tolist()
inputs = tokenizer(sentences, return_tensors='pt', padding=True,truncation=True)
featuresINeed = model(inputs['input_ids'])['pooler_output']

With the code above, I run out of memory. I tried splitting the data into chunks and calling torch.cuda.empty_cache(), but that does not seem to free all of the memory. I tried both with and without a GPU. My dataset has 60,000 sentences (possibly more in the future), and my fine-tuned model is based on BERT large. I will have a 24 GB GPU available.

Any suggestions?

Keep in mind that my main goal is to have a single language model both predict the next token and extract features of the current sentence.


Solution

  • The code snippet has several issues.

    1. It does not use the GPU at all. You need to move both the model and the inputs to the GPU explicitly, e.g.:
    import torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)
    
    2. By default, PyTorch prepares for computing gradients, which requires storing the intermediate results of the forward pass. Wrap the call in with torch.no_grad(): to ensure that no gradients are tracked.

    3. 60k sentences are too many for a single forward pass, regardless of your GPU memory. In any case, you need to split the data into smaller batches and process them one at a time, for example with a PyTorch DataLoader or the Hugging Face Datasets library.
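
The three points above can be combined into one loop. The following is a minimal sketch, not a tuned implementation: the function name extract_pooled_features and the default batch_size=32 are hypothetical choices for illustration, and it assumes a non-empty list of sentences plus a tokenizer and model loaded as in the question.

```python
import torch


def extract_pooled_features(sentences, tokenizer, model, device, batch_size=32):
    """Return one pooled feature vector per sentence as a CPU tensor.

    The function name and batch_size=32 are illustrative assumptions.
    `sentences` must be a non-empty list of strings.
    """
    model.to(device)
    model.eval()  # switch off dropout so the features are deterministic
    chunks = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        # Tokenize per batch, so padding only extends to the longest
        # sentence in this batch rather than in the whole dataset.
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():  # do not record the autograd graph
            outputs = model(**inputs)
        # Copy the result to the CPU so that only one batch of
        # activations is held on the GPU at any time.
        chunks.append(outputs["pooler_output"].cpu())
    return torch.cat(chunks, dim=0)
```

With the objects from the question, this would be called as extract_pooled_features(sentences, tokenizer, model, device). Passing the full tokenizer output with model(**inputs) also hands the attention mask to the model, so padding tokens are ignored; the original call passed only input_ids. If memory is still tight, lowering batch_size is the first thing to try.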