
Redundant print with multiprocessing


This is my first attempt at multiprocessing using the Python multiprocessing library. A simplified version of the code is below:

import multiprocessing as mp
from dataclasses import dataclass
from typing import Dict, NoReturn
import time
import logging
import signal

import numpy as np


@dataclass
class TmpData:
    name: str
    value: int


def worker(name: str, data: TmpData) -> NoReturn:
    logger_obj = mp.log_to_stderr()
    logger_obj.setLevel(logging.INFO)
    logger_obj.info(f"name: {name}; value: {data.value}")

    if name == "XYZ":
        raise RuntimeError("XYZ worker failed")

    time.sleep(data.value)


def init_worker_processes() -> None:
    signal.signal(signal.SIGINT, signal.SIG_IGN)


if __name__ == "__main__":
    map_data: Dict[str, TmpData] = {
        key: TmpData(name=key, value=np.random.randint(5, 15))
        for key in ["ABC", "DEF", "XYZ"]
    }

    main_logger = logging.getLogger()
    with mp.get_context("spawn").Pool(
        processes=2,
        initializer=init_worker_processes(),
    ) as pool:
        results = []
        for key in map_data:
            try:
                results.append(
                    pool.apply_async(
                        worker,
                        args=(
                            key,
                            map_data[key],
                        ),
                    )
                )
            except KeyboardInterrupt:
                pool.terminate()

        pool.close()
        pool.join()

        for result in results:
            try:
                result.get()
            except Exception as err:
                main_logger.error(f"{err}")

This outputs something like following -

[INFO/SpawnPoolWorker-2] name: ABC; value: 10
[INFO/SpawnPoolWorker-1] name: DEF; value: 10
[INFO/SpawnPoolWorker-2] name: XYZ; value: 12
[INFO/SpawnPoolWorker-2] name: XYZ; value: 12
[INFO/SpawnPoolWorker-2] process shutting down
[INFO/SpawnPoolWorker-2] process shutting down
[INFO/SpawnPoolWorker-2] process exiting with exitcode 0
[INFO/SpawnPoolWorker-1] process shutting down
[INFO/SpawnPoolWorker-2] process exiting with exitcode 0
[INFO/SpawnPoolWorker-1] process exiting with exitcode 0
XYZ worker failed

What concerns me is that [INFO/SpawnPoolWorker-2] name: XYZ; value: 12 is printed twice. I believe it is only a printing issue, not two processes being spawned, since the "XYZ worker failed" error message appears only once. The issue does not occur when the pool is initialized with 3 processes.

Now I want to understand the root cause and how to fix it. Can someone help me understand what I may be doing wrong?


Solution

  • The issue is how you add the stderr logger. When you call mp.log_to_stderr, it does not remove the existing handlers on the process's logger; it adds another handler that streams to stderr. In other words, every time worker(...) runs, it attaches an additional logging.StreamHandler to the handlers already on that worker process's logger.

    Step by step:

    1. Worker 1 takes the first job and creates its first logging.StreamHandler
    2. Worker 2 takes the second job and creates its first logging.StreamHandler
    3. Worker 1 finishes the first job (for example)
    4. Worker 1 takes the third job and creates its second logging.StreamHandler. From now on, every log record from this worker is printed twice.
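
    The accumulation can be shown in a minimal single-process sketch (no pool needed): each call to mp.log_to_stderr() attaches one more StreamHandler to the same process-level logger.

    ```python
    import multiprocessing as mp

    # First call: attaches one StreamHandler to the multiprocessing logger.
    mp.log_to_stderr()
    count_after_first = len(mp.get_logger().handlers)

    # Second call in the same process: keeps the first handler and adds another,
    # which is exactly what happens when a pool worker runs two jobs.
    mp.log_to_stderr()
    count_after_second = len(mp.get_logger().handlers)

    print(count_after_first, count_after_second)
    ```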

    To print the existing logger handlers you can use the following snippet:

    def worker(name: str, data: TmpData) -> None:
        _ = mp.log_to_stderr()

        # show the handlers already attached to this process's logger:
        logger = mp.get_logger()
        process_name = mp.current_process().name
        print(process_name, logger.handlers)
    

    This will output:

    SpawnPoolWorker-1 [<StreamHandler <stderr> (NOTSET)>]
    SpawnPoolWorker-2 [<StreamHandler <stderr> (NOTSET)>]
    SpawnPoolWorker-1 [<StreamHandler <stderr> (NOTSET)>, <StreamHandler <stderr> (NOTSET)>]
    

    As you can see, SpawnPoolWorker-1 now has two StreamHandlers, so it prints every log record twice (once per handler).
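
    If you cannot easily move the setup into the pool initializer, a defensive variant (a sketch, not the recommended fix; get_worker_logger is an illustrative name, not from the original code) is to attach the handler only when the process's logger has none yet:

    ```python
    import logging
    import multiprocessing as mp


    def get_worker_logger() -> logging.Logger:
        # Attach a stderr handler only if this process's logger has none yet,
        # so repeated calls from the same worker don't stack up handlers.
        logger = mp.get_logger()
        if not logger.handlers:
            mp.log_to_stderr()
            logger.setLevel(logging.INFO)
        return logger
    ```

    Calling this at the top of worker(...) is then safe no matter how many jobs the worker picks up.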

    Solution:

    The correct way is to set up the logger in init_worker_processes, which runs exactly once per worker process:

    def worker(name: str, data: TmpData) -> None:
        # reuse the existing logger, which already has a stderr StreamHandler
        logger = mp.get_logger()
        logger.info(f"name: {name}; value: {data.value}")

        if name == "XYZ":
            raise RuntimeError("XYZ worker failed")

        time.sleep(data.value)


    def init_worker_processes() -> None:
        # this runs exactly once per worker process
        logger = mp.log_to_stderr()
        logger.setLevel(logging.INFO)

        signal.signal(signal.SIGINT, signal.SIG_IGN)
    
    

    Hope this helps.

    PS: You will also need to fix the pool initializer. As written, initializer=init_worker_processes() calls the function immediately in the parent and passes its return value, None, to the pool. The correct way is to pass the init_worker_processes function itself, without parentheses:

        with mp.get_context("spawn").Pool(
            processes=2,
            initializer=init_worker_processes,
        ) as pool:
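
    To convince yourself the initializer now runs inside the workers, here is a quick standalone check (init_worker and square are illustrative names, not from the original code):

    ```python
    import multiprocessing as mp
    import os


    def init_worker():
        # runs exactly once in each worker process, before it picks up any task
        print(f"initialized worker pid={os.getpid()}")


    def square(x):
        return x * x


    if __name__ == "__main__":
        # init_worker is passed without parentheses, so the Pool calls it
        # inside each spawned worker rather than in the parent
        with mp.get_context("spawn").Pool(processes=2, initializer=init_worker) as pool:
            results = pool.map(square, range(4))
        print(results)  # [0, 1, 4, 9]
    ```

    You should see one "initialized worker" line per pool process before any results are produced.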