Tags: pytorch, parallel-processing, out-of-memory, batch-processing, huggingface

Parallelize inference with huggingface using torch


I am running an inference model on an Ubuntu machine with only 8 GB of RAM, and I just realised the predictions (logits) are not generated in batches, so my process is getting killed due to OOM issues.

tokenized_test = tokenizer(dataset["test"]["text"], padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    # A single forward pass over the entire tokenized test split
    logits = model(**tokenized_test).logits

This is where I run out of memory. What is the best way to do this in batches (or to parallelize/sequentialize it) so I avoid the OOM issue? I am ultimately looking for the solution that requires the least amount of code changes.


Source

I have built my code based on this tutorial:

https://huggingface.co/docs/transformers/tasks/sequence_classification

Increasing the dataset size will eventually make you go OOM too.


Solution

  • Try converting the test set to a TensorDataset and then using a DataLoader, something like this:

    from torch.utils.data import DataLoader, TensorDataset
    
    batch_size = 32
    
    # Keep the key order so each batch can be mapped back to keyword arguments
    keys = list(tokenized_test.keys())  # e.g. input_ids, attention_mask, ...
    test_dataset = TensorDataset(*[tokenized_test[key] for key in keys])
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    
    model.eval()
    with torch.no_grad():
        logits_list = []
        for batch in test_dataloader:
            # Pass the tensors by name; positional unpacking can silently pair
            # them with the wrong arguments (e.g. token_type_ids as attention_mask)
            batch_logits = model(**dict(zip(keys, batch))).logits
            logits_list.append(batch_logits)
    
    logits = torch.cat(logits_list)
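
  • If you want even fewer changes, you can skip the TensorDataset entirely and just slice the tokenizer output in fixed-size chunks. This is only a minimal sketch, assuming tokenized_test is the BatchEncoding built with return_tensors="pt" as in the question and that everything stays on the same (CPU) device:

    import torch
    
    batch_size = 32
    n = tokenized_test["input_ids"].size(0)
    
    model.eval()
    with torch.no_grad():
        logits_list = []
        for start in range(0, n, batch_size):
            # Slice every tokenizer output (input_ids, attention_mask, ...) the same way
            batch = {key: tensor[start:start + batch_size] for key, tensor in tokenized_test.items()}
            logits_list.append(model(**batch).logits)
    
    logits = torch.cat(logits_list)

    Either way the concatenated logits match the original one-shot call (padding was already applied over the whole test set); the difference is that each forward pass now only has to hold batch_size sequences in memory.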