Tags: django, database, multithreading, celery, threadpool

Close Django DB connections after ThreadPoolExecutor shutdown


I want to use ThreadPoolExecutor in Django and reuse one DB connection per worker thread, so that a new DB connection is not created for every sub_task. However, the DB connections are not closed after the ThreadPoolExecutor is shut down. I know I can close the connection at the end of each sub_task, but then a connection is created for every task and never reused. ThreadPoolExecutor has an initializer parameter, but there is nothing like an on_destroy hook that is called when a worker thread exits. In my setup, main_task runs as a Celery task.

from concurrent.futures import ThreadPoolExecutor


def sub_task():
    # some db operations
    ...


def main_task(max_workers):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for _ in range(10):
            executor.submit(sub_task)
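To illustrate the problem, here is a minimal stdlib-only sketch of the same pattern. It assumes an in-memory sqlite3 connection as a stand-in for Django's per-thread DB connection (Django also keeps its connections in thread-local storage); the names init_worker, _local and opened are hypothetical. The initializer opens one connection per worker thread and every sub_task reuses it, but because ThreadPoolExecutor has no matching teardown hook, the connections are still open after the with block exits.

```python
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor

# Thread-local slot, standing in for Django's per-thread connection handler.
_local = threading.local()
opened = []  # every connection opened by any worker thread
opened_lock = threading.Lock()


def init_worker():
    # ThreadPoolExecutor calls this once in each new worker thread.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    _local.conn = conn
    with opened_lock:
        opened.append(conn)


def sub_task(i):
    # Reuses the connection opened by this thread's initializer.
    return _local.conn.execute("SELECT ?", (i,)).fetchone()[0]


def main_task(max_workers):
    with ThreadPoolExecutor(max_workers=max_workers, initializer=init_worker) as executor:
        return list(executor.map(sub_task, range(10)))


results = main_task(max_workers=3)
print(results)      # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(len(opened))  # at most 3: one connection per worker thread, not per task
# There is no teardown hook, so the connections outlive the executor:
print(opened[0].execute("SELECT 1").fetchone())  # (1,)
```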

Solution

  • I wrote a custom ThreadPoolExecutor that keeps a list of the worker threads' DB connections (the thread initializer adds each thread's connection handlers to that list) and closes all of them when the executor shuts down. Note that if you also want to use your own initializer function, you must pass it as a keyword argument, not a positional one: only kwargs['initializer'] is wrapped, and a positional initializer would clash with the keyword that __init__ injects.

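    import threading
    import traceback
    from concurrent.futures import ThreadPoolExecutor
    
    from django import db
    
    
    class DBSafeThreadPoolExecutor(ThreadPoolExecutor):
        """ThreadPoolExecutor that closes its worker threads' Django DB
        connections when the executor is shut down."""
    
        def __init__(self, *args, **kwargs):
            # The user's initializer must be passed as a keyword argument so it
            # can be wrapped; the executor itself is prepended to initargs.
            kwargs['initializer'] = self.generate_initializer(kwargs.get('initializer'))
            kwargs['initargs'] = (self,) + tuple(kwargs.get('initargs') or ())
            self.threads_db_conns = []
            self._conns_lock = threading.Lock()
            super().__init__(*args, **kwargs)
    
        def generate_initializer(self, initializer):
            def new_initializer(executor, *args):
                try:
                    if initializer is not None:
                        initializer(*args)
                finally:
                    executor.on_thread_init()
            return new_initializer
    
        def on_thread_init(self):
            # Runs once in each new worker thread. Django's connection handler is
            # thread-local, so these wrappers belong to this worker only.
            with self._conns_lock:
                self.threads_db_conns.extend(db.connections.all())
    
        def on_executor_shutdown(self):
            # Runs inside one of the workers: wait until the other workers have
            # exited, then close every connection the workers opened.
            for t in list(self._threads):
                if t is not threading.current_thread():
                    t.join()
            for curr_conn in self.threads_db_conns:
                try:
                    # Allow closing a connection created in another thread.
                    curr_conn.inc_thread_sharing()
                    curr_conn.close()
                except Exception:
                    print(f'error while closing connection {curr_conn.alias}')
                    traceback.print_exc()
    
        def shutdown(self, *args, **kwargs):
            # submit() raises RuntimeError once the executor is shut down, so only
            # schedule the cleanup on the first call (an explicit shutdown()
            # followed by leaving a `with` block calls shutdown twice).
            # Caveat: shutdown(cancel_futures=True) can cancel this cleanup task
            # if it is still queued behind busy workers.
            if not self._shutdown:
                self.submit(self.on_executor_shutdown)
            super().shutdown(*args, **kwargs)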
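A simpler variant of the same idea, sketched here with the standard library only and again assuming sqlite3 connections as a stand-in for Django's: collect each worker's connection in the initializer, then close them all from the calling thread right after super().shutdown(wait=True) has joined the workers. The class and attribute names (ConnClosingThreadPoolExecutor, worker_conns, conn) are hypothetical, and check_same_thread=False plays roughly the role that inc_thread_sharing() plays for Django connections.

```python
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor


class ConnClosingThreadPoolExecutor(ThreadPoolExecutor):
    """Hypothetical stdlib analogue: one sqlite3 connection per worker thread,
    all closed by the calling thread after the workers have been joined."""

    def __init__(self, *args, initializer=None, initargs=(), **kwargs):
        # As in the Django version, a user initializer must be a keyword argument.
        self._user_initializer = initializer
        self._user_initargs = tuple(initargs)
        self._local = threading.local()
        self._conns_lock = threading.Lock()
        self.worker_conns = []
        super().__init__(*args, initializer=self._init_worker, **kwargs)

    def _init_worker(self):
        try:
            if self._user_initializer is not None:
                self._user_initializer(*self._user_initargs)
        finally:
            # check_same_thread=False lets another thread close it later.
            conn = sqlite3.connect(":memory:", check_same_thread=False)
            self._local.conn = conn
            with self._conns_lock:
                self.worker_conns.append(conn)

    @property
    def conn(self):
        """The connection owned by the calling worker thread."""
        return self._local.conn

    def shutdown(self, wait=True, **kwargs):
        super().shutdown(wait=wait, **kwargs)
        if wait:
            # Every worker has exited, so no task can still be using a connection.
            for conn in self.worker_conns:
                conn.close()


def sub_task(executor, i):
    # Reuses this worker's connection instead of opening a new one per task.
    return executor.conn.execute("SELECT ?", (i * i,)).fetchone()[0]


with ConnClosingThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(sub_task, executor, i) for i in range(10)]
    squares = [f.result() for f in futures]

print(squares)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Because the cleanup runs only after the workers are joined, no connection can be closed while a task is still using it, and no extra task has to be submitted. The trade-off is that with shutdown(wait=False) this variant leaves the connections open, a case the submit-based version above still handles.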