This is my first attempt at multiprocessing using Python's multiprocessing library. A simplified version of the code is below:
import multiprocessing as mp
from dataclasses import dataclass
from typing import Dict, NoReturn
import time
import logging
import signal

import numpy as np


@dataclass
class TmpData:
    name: str
    value: int


def worker(name: str, data: TmpData) -> NoReturn:
    logger_obj = mp.log_to_stderr()
    logger_obj.setLevel(logging.INFO)
    logger_obj.info(f"name: {name}; value: {data.value}")
    if name == "XYZ":
        raise RuntimeError("XYZ worker failed")
    time.sleep(data.value)


def init_worker_processes() -> None:
    signal.signal(signal.SIGINT, signal.SIG_IGN)


if __name__ == "__main__":
    map_data: Dict[str, TmpData] = {
        key: TmpData(name=key, value=np.random.randint(5, 15))
        for key in ["ABC", "DEF", "XYZ"]
    }
    main_logger = logging.getLogger()
    with mp.get_context("spawn").Pool(
        processes=2,
        initializer=init_worker_processes(),
    ) as pool:
        results = []
        for key in map_data:
            try:
                results.append(
                    pool.apply_async(
                        worker,
                        args=(
                            key,
                            map_data[key],
                        ),
                    )
                )
            except KeyboardInterrupt:
                pool.terminate()
        pool.close()
        pool.join()
        for result in results:
            try:
                result.get()
            except Exception as err:
                main_logger.error(f"{err}")
This outputs something like the following:
[INFO/SpawnPoolWorker-2] name: ABC; value: 10
[INFO/SpawnPoolWorker-1] name: DEF; value: 10
[INFO/SpawnPoolWorker-2] name: XYZ; value: 12
[INFO/SpawnPoolWorker-2] name: XYZ; value: 12
[INFO/SpawnPoolWorker-2] process shutting down
[INFO/SpawnPoolWorker-2] process shutting down
[INFO/SpawnPoolWorker-2] process exiting with exitcode 0
[INFO/SpawnPoolWorker-1] process shutting down
[INFO/SpawnPoolWorker-2] process exiting with exitcode 0
[INFO/SpawnPoolWorker-1] process exiting with exitcode 0
XYZ worker failed
What concerns me is that [INFO/SpawnPoolWorker-2] name: XYZ; value: 12 is printed twice. I believe it is only a printing issue, not two processes running the same task, because the "XYZ worker failed" error appears only once. The issue does not occur when the pool is initialized with 3 processes.
Now I want to understand the root cause and how to fix it. What am I doing wrong?
The issue is how you add the stderr logger. mp.log_to_stderr does not remove the logger's existing handlers; it appends one more logging.StreamHandler (writing to stderr) on every call. Because you call it inside worker, every task executed by a given worker process adds another handler to that process's multiprocessing logger.
Step by step, in the worker process that ran two tasks:
after the first task:  [StreamHandler]
after the second task: [StreamHandler, StreamHandler]
With two handlers attached, every log record is printed twice. To print the logger's existing handlers you can use the following snippet:
def worker(name: str, data: TmpData) -> NoReturn:
    _ = mp.log_to_stderr()
    # print the existing logger's handlers:
    logger = mp.get_logger()
    process_name = mp.current_process().name
    print(process_name, logger.handlers)
This will output:
SpawnPoolWorker-1 [<StreamHandler <stderr> (NOTSET)>]
SpawnPoolWorker-2 [<StreamHandler <stderr> (NOTSET)>]
SpawnPoolWorker-1 [<StreamHandler <stderr> (NOTSET)>, <StreamHandler <stderr> (NOTSET)>]
As you can see, the second worker process has two StreamHandlers, so it emits every log record twice (once per handler).
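The accumulation is easy to reproduce in a single process, with no pool involved. This minimal sketch shows both halves of the problem: each mp.log_to_stderr call appends one more handler, and a logger with two handlers emits every record twice (the second part uses a plain logging logger with a StringIO buffer so the duplication is easy to count):

```python
import io
import logging
import multiprocessing as mp

# Each call to mp.log_to_stderr() appends another StreamHandler to the
# shared multiprocessing logger; nothing removes the old ones.
mp.log_to_stderr()
mp.log_to_stderr()
print(len(mp.get_logger().handlers))  # 2

# A logger with two handlers emits every record once per handler:
buf = io.StringIO()
logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(buf))
logger.addHandler(logging.StreamHandler(buf))
logger.info("hello")
print(buf.getvalue().count("hello"))  # 2
```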
Solution:
The correct way to set up logging is to do it once in init_worker_processes:
def worker(name: str, data: TmpData) -> NoReturn:
    # get the existing logger, which already has a stderr StreamHandler
    logger = mp.get_logger()
    logger.info(f"name: {name}; value: {data.value}")
    if name == "XYZ":
        raise RuntimeError("XYZ worker failed")
    time.sleep(0.01)


def init_worker_processes() -> None:
    # this runs only once per worker process
    logger = mp.log_to_stderr()
    logger.setLevel(logging.INFO)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
Hope this helps.
PS: You also need to fix the pool initializer. The correct way is to pass the init_worker_processes function object itself, not the result of calling it:
with mp.get_context("spawn").Pool(
    processes=2,
    initializer=init_worker_processes,
) as pool:
With initializer=init_worker_processes() the function runs once in the parent process, and Pool receives its return value, None, so the workers never run it.
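Putting both fixes together, here is a runnable sketch of the corrected program (I shortened the sleep and fixed value=1 so it finishes quickly; with it, each log line appears exactly once):

```python
import logging
import multiprocessing as mp
import signal
import time
from dataclasses import dataclass
from typing import Dict


@dataclass
class TmpData:
    name: str
    value: int


def worker(name: str, data: TmpData) -> None:
    # Reuse the logger configured once per process in the initializer.
    logger = mp.get_logger()
    logger.info(f"name: {name}; value: {data.value}")
    if name == "XYZ":
        raise RuntimeError("XYZ worker failed")
    time.sleep(0.01)


def init_worker_processes() -> None:
    # Runs exactly once in each worker process: exactly one handler.
    logger = mp.log_to_stderr()
    logger.setLevel(logging.INFO)
    signal.signal(signal.SIGINT, signal.SIG_IGN)


if __name__ == "__main__":
    map_data: Dict[str, TmpData] = {
        key: TmpData(name=key, value=1) for key in ["ABC", "DEF", "XYZ"]
    }
    with mp.get_context("spawn").Pool(
        processes=2,
        initializer=init_worker_processes,  # the function itself, no ()
    ) as pool:
        results = [
            pool.apply_async(worker, args=(key, map_data[key]))
            for key in map_data
        ]
        pool.close()
        pool.join()
        for result in results:
            try:
                result.get()
            except Exception as err:
                logging.getLogger().error(f"{err}")
```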