I am attempting to run a multiprocessed job within a larger Python class. In simplified form, the class looks as follows:

```python
# x.py
import multiprocessing as mp

class Thing:
    def test(self):
        with mp.Pool() as p:
            yield from p.map(str, range(20))
```
When I import this class into a script such as:

```python
from x import Thing

t = Thing()
for item in t.test():
    print(item)
```
I run into the well-known error:

```
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
```
My question is: how do I guard against this behaviour without requiring the user of my class to write `if __name__ == "__main__":` every time they call my class's method? Is there a way to do this inside the class definition?
I have tried writing:

```python
from x import Thing

t = Thing()
if __name__ == "__main__":
    for item in t.test():
        print(item)
```
This solves the issue, but I do not want this to be the way a user interacts with this class.
Since the value of `__name__` in the main module of a spawned child process is `'__mp_main__'`, as discussed here, you can place a guard in the function that spawns child processes. Check whether any of the ancestor frames has a `__name__` with the value `'__mp_main__'` in its global namespace; if so, the current process is a child, and execution should stop so that no further children are spawned. Because `test` is a generator, its body (and therefore the check) only runs when iteration starts, so `sys._getframe(1)` is the frame of whatever loop is iterating over it:
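```python
# x.py
import sys
import multiprocessing as mp

class Thing:
    def test(self):
        # Walk up the call stack. In a spawned child, the user's script is
        # re-executed as a module named '__mp_main__', so finding such a
        # frame means we are in a child and must not start another Pool.
        frame = sys._getframe(1)
        while frame is not None:
            # .get() tolerates frames whose globals lack '__name__'
            if frame.f_globals.get('__name__') == '__mp_main__':
                return  # ends the generator without yielding anything
            frame = frame.f_back
        with mp.Pool() as p:
            yield from p.map(str, range(20))
```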
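If several methods need this guard, the frame walk can be factored into a small helper. The following is a minimal sketch, not part of `multiprocessing`: `called_from_mp_main` is a hypothetical name, and it assumes the `'spawn'` start method (the default on Windows and macOS), under which the parent's main script is re-executed as `__mp_main__` in each child.

```python
import sys
import multiprocessing as mp


def called_from_mp_main() -> bool:
    """Hypothetical helper: return True if any caller's module globals
    have __name__ == '__mp_main__', i.e. we are running inside a spawned
    child while it re-executes the parent's main script."""
    frame = sys._getframe(1)  # start at the caller, not this helper
    while frame is not None:
        # .get() tolerates frames whose globals lack '__name__'
        if frame.f_globals.get('__name__') == '__mp_main__':
            return True
        frame = frame.f_back
    return False


class Thing:
    def test(self):
        if called_from_mp_main():
            return  # child process: yield nothing, spawn nothing
        with mp.Pool() as p:
            yield from p.map(str, range(20))
```

Since `called_from_mp_main` skips only its own frame, calling it from inside `test` still inspects `test`'s caller and everything above it, exactly like the inline loop.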
EDIT: An even simpler way to detect a child process is to check that the current process's name is `'MainProcess'`:

```python
# x.py
import multiprocessing as mp

class Thing:
    def test(self):
        # Only the parent keeps the default name 'MainProcess';
        # pool workers get names such as 'SpawnPoolWorker-1'.
        if mp.current_process().name == 'MainProcess':
            with mp.Pool() as p:
                yield from p.map(str, range(20))
```
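The `current_process().name` check can also be packaged as a decorator so that each generator method does not repeat the `if`. A minimal sketch: `main_process_only` is a hypothetical name (not a `multiprocessing` API), and, like the inline check, it assumes nobody has renamed the main process away from its default `'MainProcess'`.

```python
import functools
import multiprocessing as mp


def main_process_only(gen_func):
    """Hypothetical decorator: the wrapped generator yields nothing
    unless it is iterated in the main process."""
    @functools.wraps(gen_func)
    def wrapper(*args, **kwargs):
        # wrapper is itself a generator, so this runs on the first next(),
        # just like the inline version.
        if mp.current_process().name != 'MainProcess':
            return
        yield from gen_func(*args, **kwargs)
    return wrapper


class Thing:
    @main_process_only
    def test(self):
        with mp.Pool() as p:
            yield from p.map(str, range(20))
```

Using `functools.wraps` keeps `Thing.test.__name__` and its docstring intact, so the decorated method still looks like `test` in tracebacks and `help()`.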