Tags: python, testing, multiprocessing, pytest, concurrent.futures

Pytest monkeypatch a multiprocess function for testing


Monkeypatching (with pytest's monkeypatch fixture) a function that uses multiple processes (via concurrent.futures.ProcessPoolExecutor) does not work as expected. If the same function is written to run in a single process or with multiple threads (via concurrent.futures.ThreadPoolExecutor), monkeypatch works as expected.

Why does monkeypatching fail in the multiprocess case, and how does one correctly monkeypatch a multiprocess function for testing?

The simplest code example that illustrates my question is below. In actual usage, I would be monkeypatching a function imported from another module.

# file_a.py

import concurrent.futures as ccf

MY_CONSTANT = "hello"

def my_function():
    return MY_CONSTANT


def singleprocess_f():
    result = []
    for _ in range(3):
        result.append(my_function())
    return result


def multithread_f():
    result = []
    with ccf.ThreadPoolExecutor() as executor:
        futures = []
        for _ in range(3):
            future = executor.submit(my_function)
            futures.append(future)
        for future in ccf.as_completed(futures):
            result.append(future.result())
    return result


def multiprocess_f():
    result = []
    with ccf.ProcessPoolExecutor() as executor:
        futures = []
        for _ in range(3):
            future = executor.submit(my_function)
            futures.append(future)
        for future in ccf.as_completed(futures):
            result.append(future.result())
    return result

I expected all tests to pass (pip install -U pytest, run tests with pytest test_file_a.py):

# test_file_a.py

from file_a import multiprocess_f, multithread_f, singleprocess_f

# PASSES:
def test_singleprocess_f(monkeypatch):
    monkeypatch.setattr("file_a.MY_CONSTANT", "world")
    result = singleprocess_f()
    assert result == ["world"] * 3

# PASSES:
def test_multithread_f(monkeypatch):
    monkeypatch.setattr("file_a.MY_CONSTANT", "world")
    result = multithread_f()
    assert result == ["world"] * 3

# FAILS:
def test_multiprocess_f(monkeypatch):
    monkeypatch.setattr("file_a.MY_CONSTANT", "world")
    result = multiprocess_f()
    assert result == ["world"] * 3

Solution

  • My guess is that you are running on a platform that by default uses the spawn method for creating new processes (e.g. Windows or macOS). If that is the case, it does not matter what value you assign to MY_CONSTANT in the main process, using monkeypatch or otherwise, because when the child pool processes are created they initialize their memory by re-executing the statement MY_CONSTANT = "hello". That is, child processes created with spawn do not inherit values from the parent process; instead they initialize their memory by re-importing the module, executing all statements at global scope (e.g. import statements, function definitions, variable assignments, etc.).
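
    A quick way to confirm this is to check which start method your platform uses. Below is a minimal sketch (not part of the original code) that prints it with multiprocessing.get_start_method(); if it reports "spawn", each worker re-imports file_a and therefore sees the original MY_CONSTANT = "hello":

    import multiprocessing as mp

    if __name__ == "__main__":
        # Typically "fork" on Linux; "spawn" on Windows and on macOS (Python 3.8+)
        print(mp.get_start_method())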

    If this is the case, then you must provide a pool initializer function using the initializer argument when creating your ProcessPoolExecutor instance. Function multiprocess_f also needs to participate in its testing by accepting an optional executor argument (a ProcessPoolExecutor) to be used for testing:

    File file_a.py

    import concurrent.futures as ccf
    
    
    MY_CONSTANT = "hello"
    
    
    def my_function():
        return MY_CONSTANT
    
    
    def singleprocess_f():
        result = []
        for _ in range(3):
            result.append(my_function())
        return result
    
    
    def multithread_f():
        result = []
        with ccf.ThreadPoolExecutor() as executor:
            futures = []
            for _ in range(3):
                future = executor.submit(my_function)
                futures.append(future)
            for future in ccf.as_completed(futures):
                result.append(future.result())
        return result
    
    
    def multiprocess_f(executor=None):
        """To allow unit testing, we may be passed an executor
        to use."""
    
        if executor is None:
            executor = ccf.ProcessPoolExecutor()
    
        result = []
        with executor:
            futures = []
            for _ in range(3):
                future = executor.submit(my_function)
                futures.append(future)
            for future in ccf.as_completed(futures):
                result.append(future.result())
        return result
    

    File test_file_a.py

    import concurrent.futures as ccf
    
    import file_a
    
    def init_pool_processes():
        file_a.MY_CONSTANT = "world"
    
    
    # PASSES:
    def test_singleprocess_f(monkeypatch):
        with monkeypatch.context() as m:
            m.setattr(file_a, "MY_CONSTANT", "world")
            result = file_a.singleprocess_f()
            assert result == ["world"] * 3
    
    # PASSES:
    def test_multithread_f(monkeypatch):
        monkeypatch.setattr(file_a, "MY_CONSTANT", "world")
        result = file_a.multithread_f()
        assert result == ["world"] * 3
    
    # PASSES:
    def test_multiprocess_f():
        # Special executor for testing:
        executor = ccf.ProcessPoolExecutor(initializer=init_pool_processes)
        result = file_a.multiprocess_f(executor)
        assert result == ["world"] * 3
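
    As a variation (my own sketch, not from the original answer), ProcessPoolExecutor also accepts an initargs tuple, so the test itself can choose which value the initializer installs; the helper name init_pool_processes_with is hypothetical:

    def init_pool_processes_with(value):
        # Runs once in each worker process and overrides the module-level constant there
        file_a.MY_CONSTANT = value


    # PASSES (variant using initargs):
    def test_multiprocess_f_initargs():
        executor = ccf.ProcessPoolExecutor(
            initializer=init_pool_processes_with, initargs=("world",)
        )
        result = file_a.multiprocess_f(executor)
        assert result == ["world"] * 3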