
Multiprocessing pool in a Python class without __name__ == "__main__" guard


I am attempting to run a multiprocessed job within a larger Python class. In a simplified form, the class looks as follows:

# x.py
import multiprocessing as mp

class Thing:
    def test(self):
        with mp.Pool() as p:
            yield from p.map(str, range(20))

When I import this class to a script such as:

from x import Thing

t = Thing()
for item in t.test():
    print(item)

I run into the commonly known issue:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

My question is, how do I guard against this behaviour, without requiring the user of my class to write

if __name__ == "__main__":

every time they run my class function? Is there a way of performing this inside the class definition?

I have tried writing

from x import Thing

t = Thing()
if __name__ == "__main__":
    for item in t.test():
        print(item)

which solves the issue, but I do not want this to be the way a user interfaces with this class.


Solution

  • When multiprocessing spawns a child process, it re-imports the main module under the name '__mp_main__' instead of '__main__'. You can therefore place a guard in the function that creates the pool: walk up the call stack and check whether any ancestor frame has a __name__ of '__mp_main__' in its global namespace. If one does, the current process is a child that is re-importing the main script, and the function should return early so that no further child processes are spawned:

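    # x.py
    import sys
    import multiprocessing as mp
    
    class Thing:
        def test(self):
            # Walk up the call stack, starting from the caller's frame.
            frame = sys._getframe(1)
            while frame:
                # A frame whose module is '__mp_main__' means a spawned child
                # is re-importing the main script, so do not start a pool.
                if frame.f_globals.get('__name__') == '__mp_main__':
                    return
                frame = frame.f_back
            with mp.Pool() as p:
                yield from p.map(str, range(20))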
    

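    EDIT: A simpler guard is to check that the current process is named 'MainProcess', which is the default name of the process that started the program. A spawned child is given a different name before it re-imports the main script, so the pool is only created in the parent: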

    import multiprocessing as mp
    
    class Thing:
        def test(self):
            if mp.current_process().name == 'MainProcess':
                with mp.Pool() as p:
                    yield from p.map(str, range(20))
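
If several methods need the same protection, the process-name check above could be factored into a decorator. The following is only a sketch: the name `main_process_only` is hypothetical and not part of multiprocessing, and it assumes the main process keeps its default name 'MainProcess'.

```python
# x.py
import functools
import multiprocessing as mp


def main_process_only(gen_func):
    """Hypothetical helper: make a generator function yield nothing
    when it is called outside the main process."""
    @functools.wraps(gen_func)
    def wrapper(*args, **kwargs):
        # Assumption: the main process still has its default name.
        if mp.current_process().name != "MainProcess":
            return  # e.g. a spawned child re-importing the caller's script
        yield from gen_func(*args, **kwargs)
    return wrapper


class Thing:
    @main_process_only
    def test(self):
        with mp.Pool() as p:
            yield from p.map(str, range(20))
```

With this guard, as with the ones above, a spawned child still re-executes the top-level code of the importing script; the guard only ensures that `t.test()` yields nothing there, so the user's loop body does not run in the child and no nested pool is started.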