Search code examples
python, stdout, python-multithreading

How to suppress stdout within a specific python thread?


I want to be able to suppress any print to stdout within a specific thread. Here is what I have tried:

import sys, io, time
from threading import Thread

def do_thread_action():
    """Thread body that tries to silence its own prints by swapping sys.stdout.

    NOTE: sys.stdout is a single module attribute, not a per-thread one, so
    the swap below is visible to every thread for as long as it is in place.
    """

    # Disable stdout (rebinds the shared sys.stdout to a throwaway buffer)
    sys.stdout = io.StringIO()

    print("don't print this 1")
    time.sleep(1)
    print("don't print this 2")
    time.sleep(1)
    print("don't print this 3")

    # Re-enable stdout (restores the interpreter's original stream)
    sys.stdout = sys.__stdout__

# Run do_thread_action in a background thread.
thread = Thread(target=do_thread_action)


thread.start()

# Sleep so the print below happens while the worker still has sys.stdout
# replaced (the worker holds the replacement for about two seconds).
time.sleep(1.5)

# Print this to stdout
print('Print this')

# Wait for the worker to finish.
thread.join()

However, this does not work because sys.stdout is global: it is shared by the new thread and the main thread.

How do I suppress the prints inside do_thread_action within the thread, but not suppress the prints outside of it?


Solution

  • So, here it is: just replace the sys.stdout object with an object that has a write method (and, to cover all cases, a flush method) which selects where the output should go for the currently running thread.

    The currently running thread can be identified by its thread name.

    Here is an almost "production ready" class which can take care of things, including even the decorators for patching the codepaths which should be guarded for printing:

    import functools
    import io
    import sys
    import threading
    import time
    from unittest import mock
    
    # Rebound to a fresh io.StringIO by main(); nothing in this snippet reads it.
    delayed_outputs = None
    
    class SelectOutput:
        """Stand-in for ``sys.stdout`` that routes writes per thread.

        ``config`` maps a thread name to an action.  A thread whose name maps
        to ``"capture"`` has its output collected in ``text_io`` (one buffer
        shared by all captured threads); every other thread writes to
        ``sys.__stdout__``.

        Usage: wrap the code that should run with ``sys.stdout`` replaced in
        ``instrument``, and wrap each thread target in ``filter`` so the
        destination is chosen when the target starts running.
        """

        def __init__(self, config):
            """Store ``config`` and create the capture buffer.

            Args:
                config: mapping of thread name -> action; only the value
                    ``"capture"`` is acted upon.
            """
            self.text_io = io.StringIO()
            # Per-thread namespace: each thread sees its own ``stdout`` attribute.
            self.ns = threading.local()
            self.config = config

        def _current_stream(self):
            """Return the stream selected for the calling thread.

            Threads that never ran a ``filter``-wrapped function (for example
            the main thread printing while ``sys.stdout`` is patched) have no
            selection yet; they fall back to ``sys.__stdout__`` instead of
            failing with ``AttributeError``.
            """
            return getattr(self.ns, "stdout", sys.__stdout__)

        def filter(self, func):
            """Decorate ``func`` so it selects the calling thread's destination.

            The choice is made each time the wrapper is called, from the name
            of the thread doing the call, and then ``func`` runs unchanged.
            """
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # The thread is identified by its name, not by any argument
                # passed to ``func``.
                thread_name = threading.current_thread().name
                if self.config.get(thread_name) == "capture":
                    self.ns.stdout = self.text_io
                else:
                    self.ns.stdout = sys.__stdout__

                return func(*args, **kwargs)
            return wrapper

        def instrument(self, func):
            """Decorate ``func`` so it runs with ``sys.stdout`` replaced by self.

            ``mock.patch`` puts the previous ``sys.stdout`` back when ``func``
            returns or raises.
            """
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with mock.patch("sys.stdout", self):
                    return func(*args, **kwargs)
            return wrapper

        def write(self, text):
            """Write ``text`` to the calling thread's stream.

            Returns whatever the underlying stream's ``write`` returns (the
            number of characters written, for text streams).
            """
            return self._current_stream().write(text)

        def flush(self):
            """Flush the calling thread's stream."""
            self._current_stream().flush()
    
    # Config keys are thread names: output from the thread named "2" is
    # captured; threads with any other name write to the real stdout.
    select_output = SelectOutput(config={"2": "capture"})
    
    @select_output.filter
    def target(thread_id):
        """Demo thread body: print an opening and a closing line.

        Each line is preceded by a pause proportional to ``thread_id``.
        """
        steps = (
            (0.1, f"At thread ID: {thread_id}"),
            (0.2, f"closing thread ID: {thread_id}"),
        )
        for delay_factor, message in steps:
            time.sleep(delay_factor * thread_id)
            print(message)
    
    @select_output.instrument
    def instrumented():
        """Start three worker threads running ``target`` and wait for them.

        Runs with ``sys.stdout`` replaced by ``select_output`` (via the
        ``instrument`` decorator), so each worker's prints go wherever the
        config says for that worker's thread name.
        """
        threads = []
        for thread_id in (1, 2, 3):
            # Name each thread after its id: the ``filter`` wrapper looks up
            # the current thread's name in the config to pick a destination.
            thread = threading.Thread(
                target=target, args=(thread_id,), name=str(thread_id)
            )
            threads.append(thread)
            thread.start()

        # Plain loop: join() is called for its side effect only.
        for thread in threads:
            thread.join()
    
    def main():
        """Run the threaded demo, then print everything that was captured."""
        global delayed_outputs
        # Kept from the original: rebinds the module global, which nothing in
        # this snippet reads afterwards.
        delayed_outputs = io.StringIO()
        instrumented()
        captured = select_output.text_io.getvalue()
        print("Delayed outputs:", captured)
    
    
    
    # Run the demo only when executed as a script, not when imported.
    if __name__ == "__main__":
        main()