Search code examples
pythonmultithreadingnetworkingsignals

Shutting a ThreadPoolExecutor down with a KeyboardInterrupt


I'm using Python's ThreadPoolExecutor to implement an asynchronous client. Every so often, the script submits a synchronous client callable to the thread pool. I'd like to be able to stop the loop with a KeyboardInterrupt.

I've written the following code.

#!/usr/bin/env python

import numpy as np
import tritonclient.http as tritonclient

import argparse
import itertools
import logging
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor

distributions = {
    'poisson': lambda w: random.expovariate(1/w),
    'uniform': lambda w: random.uniform(0, 2*w),
}

class Client:
    def __init__(self, url, model):
        self.client = tritonclient.InferenceServerClient(url)
        config = self.client.get_model_config(model)
        self.inputs = config['input']
        self.outputs = [output['name'] for output in config['output']]
        self.model = model

    def __call__(self):
        inputs = []
        for config in self.inputs:
            assert config['data_type'] == 'TYPE_FP32'
            shape = [1] + config['dims']
            datatype = config['data_type'].removeprefix('TYPE_')
            input = tritonclient.InferInput(config['name'], shape, datatype)
            array = np.random.default_rng().random(shape, dtype=np.float32)
            input.set_data_from_numpy(array)
            inputs.append(input)
        result = self.client.infer(self.model, inputs)
        for output in self.outputs:
            result.get_output(output)

def benchmark(fn):
    t_i = time.time()
    fn()
    t_f = time.time()
    print(t_i, t_f - t_i)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--distribution', choices=distributions.values(),
                        type=distributions.get, default=lambda w: w)
    parser.add_argument('-n', '--nrequests', default=-1, type=int)
    parser.add_argument('-o', '--open', '--open-loop', action='store_true')
    parser.add_argument('-u', '--url', default='localhost:8000')
    parser.add_argument('-v', '--verbose', action='count', default=0)
    parser.add_argument('model')
    rate = parser.add_mutually_exclusive_group()
    rate.add_argument('-w', '--wait', '--delay', '-l', '--lambda',
                      default=0, type=float)
    rate.add_argument('-r', '--rate', '-f', '--frequency', type=float)
    args = parser.parse_args()

    level = (logging.DEBUG if args.verbose > 1
            else logging.INFO if args.verbose
            else logging.WARNING)
    logging.basicConfig(level=level)

    if args.rate:
        args.wait = 1/args.rate
    logging.debug(args)

    client = Client(args.url, args.model)

    with ThreadPoolExecutor() as executor:
        try:
            for _ in (itertools.count() if args.nrequests < 0
                        else range(args.nrequests)):
                if args.open:
                    executor.submit(benchmark, client)
                else:
                    benchmark(client)
                time.sleep(args.distribution(args.wait))
        except KeyboardInterrupt:
            pass
        except BrokenPipeError:
            pass

It hangs on the first Control-C and requires two more before it finally exits with the following error. For what it's worth, my first use of ChatGPT was today in order to figure this out. It didn't work. What a let-down.

1717617460.23863 0.003475189208984375
1717617460.250774 0.0033867359161376953
1717617460.2690861 0.0033500194549560547
^C^CTraceback (most recent call last):
  File "/data/pcoppock/mlos/apps/tritonclient", line 73, in <module>
    with ThreadPoolExecutor() as executor:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 649, in __exit__
    self.shutdown(wait=True)
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 235, in shutdown
    t.join()
  File "/usr/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
KeyboardInterrupt
^CException ignored in: <module 'threading' from '/usr/lib/python3.10/threading.py'>
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1537, in _shutdown
    atexit_call()
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 31, in _python_exit
    t.join()
  File "/usr/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
    if lock.acquire(block, timeout):
KeyboardInterrupt: 

linux$ 

Solution

  • The issue lay with the Triton client module. The Client objects weren't thread-safe. After I modified the code to initialize a new client in every thread, a single ctrl-C eventually stopped the program.