python pytorch multiprocessing huggingface-transformers

multiprocessing with HF's transformers uses all CPU cores despite a limited number of worker processes


I am using torch.multiprocessing to parallelize my for-loop as follows:

import torch
import torch.multiprocessing as mp

torch.set_num_threads(1)
mp.set_start_method("spawn", force=True)

seqs = [...]  # List of strings (placeholder)

def func1(x):
    # Some numpy calculation
    return a, b

def func2(a, b):
    # Run initial computations with torch
    # Run forward inference of nn.Module (wrapper of transformers model)
    # Perform additional computations with torch
    return c

def main_func(x):
    a, b = func1(x)
    c = func2(a, b)
    return c


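if __name__ == "__main__":
    # "spawn" re-imports this module in every worker, so the pool must be created under this guard
    with mp.Pool(processes=10) as pool:
        results = pool.map(main_func, seqs)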

In this code, func1 performs some numpy calculations, while func2 represents the forward call of an nn.Module, which wraps the EsmForMaskedLM model.

Specifically, the mentioned nn.Module can be simplified as below

import torch
from torch import nn
from transformers import AutoTokenizer, EsmForMaskedLM, BatchEncoding


class ESM2(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
        self.model.to("cpu")

    def forward(self, inputs: BatchEncoding) -> torch.Tensor:
        results = self.model(**inputs)   # inputs are already tokenized and on the CPU
        return results

When I run the code sequentially in a for-loop, everything works as expected. However, when I parallelize it with torch.multiprocessing, it occupies all CPU cores on my machine, even though I have limited the number of processes. While debugging, I found that this happens only when func2 (i.e., the forward function) is called. I suspect there is a problem with the forward function, but I'm not sure. Can someone please help me with this issue? Thank you very much!

P.S. If you need any additional information, please let me know; I am happy to provide more.


Solution

  • Apparently, PyTorch uses a parallelization library called OpenMP, and according to this answer:

    OpenMP does multi-threading within a process, and the default number of threads is typically the number that the CPU can actually run simultaneously
    ...
    So, what happens on that quad-core CPU if you run a multiprocessing program that runs 4 Python processes, and each calls an OpenMP function runs 4 threads? You end up running 16 threads on 4 cores

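    So when I set os.environ["OMP_NUM_THREADS"] = "1", it performs as expected. The variable should be set before torch is imported so that it is already in place when the OpenMP runtime initializes.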

    Other resources: HF discussion
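
For reference, here is a minimal sketch of how the fix could be wired into a script like the one in the question. It is an illustration under assumptions, not the asker's actual code: thread_settings is a hypothetical stand-in for main_func, run_pool and init_worker are made-up helper names, and MKL_NUM_THREADS is added on the assumption that the installed torch/numpy builds may use MKL. The torch import is wrapped in try/except only so that the sketch also runs where torch is not installed.

```python
import os

# Assumption: set the caps before numpy/torch are imported, because the
# OpenMP/MKL runtimes read these variables when they initialize.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import multiprocessing as mp  # torch.multiprocessing wraps this module

try:
    import torch
except ImportError:  # keeps the sketch runnable without torch
    torch = None


def init_worker():
    """Pool initializer: runs once in each worker process."""
    if torch is not None:
        torch.set_num_threads(1)  # cap torch's intra-op thread pool


def thread_settings(_):
    """Hypothetical stand-in for main_func: report what this worker sees."""
    torch_threads = torch.get_num_threads() if torch is not None else None
    return os.environ.get("OMP_NUM_THREADS"), torch_threads


def run_pool(n_tasks=4, processes=2):
    """Start a spawn-based pool and collect each task's thread settings."""
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=processes, initializer=init_worker) as pool:
        return pool.map(thread_settings, range(n_tasks))


if __name__ == "__main__":
    # With "spawn", workers re-import this module, so they also run the
    # os.environ assignments above before torch is imported.
    print(run_pool())
```

Because spawned workers inherit the parent's environment and re-import the main module, the cap applies in every process. With 10 workers and one thread each, the pool should then keep roughly 10 cores busy rather than all of them.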