The warnings.catch_warnings() context manager is not thread-safe. How do I use it in a parallel processing environment?
The code below solves a maximization problem using parallel processing with Python's multiprocessing module. It takes a list of (immutable) widgets, partitions them up (see Efficient multiprocessing of massive, brute force maximization in Python 3), finds the maxima ("finalists") of all the partitions, and then finds the maximum ("champion") of those finalists. If I understand my own code correctly (and I wouldn't be here if I did), I'm sharing memory with all the child processes to give them the input widgets, and multiprocessing uses an operating-system-level pipe and pickling to send the finalist widgets back to the main process when the workers are done.
I want to catch the redundant widget warnings caused by the widgets' re-instantiation after the unpickling that happens when they come out of the inter-process pipe. When widget objects instantiate, they validate their own data, emitting warnings from the Python standard warnings module to tell the app's user that the widget suspects there is a problem with the user's input data. Because unpickling causes objects to instantiate, my understanding of the code implies that each widget object is re-instantiated exactly once, if and only if it is a finalist, after it comes out of the pipe -- see the next section for why this isn't correct.
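As an illustration of that premise, here is a hypothetical widget class (an assumption, not my real code) that validates in __init__ and pickles via __reduce__, so that unpickling re-runs __init__ and re-emits the warning (note that, by default, pickle would not call __init__ at all; the __reduce__ is what makes the example match the behavior I'm describing):

import pickle
import warnings

class Widget:
    # Hypothetical stand-in for the real widget class.
    def __init__(self, data):
        self.data = data
        if data < 0:  # stand-in for the real input validation
            warnings.warn("suspect widget data: %r" % (data,))

    def __reduce__(self):
        # Pickle this object as "call Widget(data) again," so
        # unpickling re-runs __init__ and therefore re-validates.
        return (Widget, (self.data,))

warnings.simplefilter("always")     # make the duplicate visible in one process
w = Widget(-1)                      # warns once at creation
w2 = pickle.loads(pickle.dumps(w))  # warns again on unpickling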
The widgets were already created before being frobnicated, so the user is already painfully aware of what input they got wrong and doesn't want to hear about it again. These are the warnings I'd like to catch with the warnings module's catch_warnings() context manager (i.e., a with statement).
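Conceptually, the suppression I want looks something like this sketch, which, as described below, does not actually work here:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # drop the redundant warnings
    champion = frobnicate_parallel(widgets)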
In my tests I've narrowed down when the superfluous warnings are emitted to somewhere between what I've labeled below as Line A and Line B. What surprises me is that the warnings are emitted in places other than just near output_queue.get(). This implies to me that multiprocessing also uses pickling to send the widgets to the workers.
The upshot is that a context manager created by warnings.catch_warnings(), even wrapped around everything from Line A to Line B with the right warnings filter set inside it, does not catch the warnings. This implies to me that the warnings are being emitted in the worker processes. Putting this context manager around the worker code does not catch them either, presumably because the worker's arguments are unpickled by multiprocessing's bootstrapping code before the target function ever starts running.
This example omits the code for deciding whether the problem size is too small to bother with forking processes, for importing multiprocessing, and for defining my_frobnal_counter and my_load_balancer.
"Call `frobnicate(list_of_widgets)` to get the widget with the most frobnals"
def frobnicate_parallel_worker(widgets, output_queue):
resultant_widget = max(widgets, key=my_frobnal_counter)
output_queue.put(resultant_widget)
def frobnicate_parallel(widgets):
output_queue = multiprocessing.Queue()
# partitions: Generator yielding tuples of sets
partitions = my_load_balancer(widgets)
processes = []
# Line A: Possible start of where the warnings are coming from.
for partition in partitions:
p = multiprocessing.Process(
target=frobnicate_parallel_worker,
args=(partition, output_queue))
processes.append(p)
p.start()
finalists = []
for p in processes:
finalists.append(output_queue.get())
# Avoid deadlocks in Unix by draining queue before joining processes
for p in processes:
p.join()
# Line B: Warnings no longer possible after here.
return max(finalists, key=my_frobnal_counter)
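For anyone who wants to run this, hypothetical stand-ins for the omitted helpers might look like the following (these are assumptions for illustration, not my real code):

import multiprocessing

def my_frobnal_counter(widget):
    # Hypothetical scoring function: the widget's frobnal count.
    return widget.frobnal_count

def my_load_balancer(widgets, n_workers=None):
    # Hypothetical partitioner: deal widgets round-robin into one
    # partition per worker, yielding each partition as a tuple.
    n_workers = n_workers or multiprocessing.cpu_count()
    partitions = [[] for _ in range(n_workers)]
    for i, widget in enumerate(widgets):
        partitions[i % n_workers].append(widget)
    return (tuple(part) for part in partitions if part)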
Years later, I finally have a solution (found while working on an unrelated problem). I've tested this on Python 3.7, 3.8, and 3.9.
Temporarily patch sys.warnoptions with the empty list []. You only need to do this around the call to process.start(). sys.warnoptions is documented as an implementation detail that you shouldn't modify manually; the official recommendations are to use functions in the warnings module and to set PYTHONWARNINGS in os.environ. Neither of those worked for me. The only thing that seems to work is patching sys.warnoptions. In a test, you can do the following:
import multiprocessing
from unittest.mock import patch

p = multiprocessing.Process(target=my_function)
with patch('sys.warnoptions', []):
    p.start()
p.join()
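My best guess at why this works: at least with the spawn start method, multiprocessing rebuilds the child interpreter's command line from the parent's interpreter flags, turning each entry of sys.warnoptions into a -W option for the child. An empty list at start() time therefore means the child comes up with only the default warning filters.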
If you don't want to use unittest.mock, just patch by hand:
import multiprocessing
import sys

p = multiprocessing.Process(target=my_function)
old_warnoptions = sys.warnoptions
try:
    sys.warnoptions = []
    p.start()
finally:
    sys.warnoptions = old_warnoptions
p.join()
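Applied to frobnicate_parallel from the question, the hand-patched version of the start-up loop would look something like this sketch:

import sys

for partition in partitions:
    p = multiprocessing.Process(
        target=frobnicate_parallel_worker,
        args=(partition, output_queue))
    processes.append(p)
    old_warnoptions = sys.warnoptions
    try:
        sys.warnoptions = []  # child starts without inherited -W options
        p.start()
    finally:
        sys.warnoptions = old_warnoptions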