Tags: python, multithreading, joblib

How to thread a generator


I have a generator object that loads quite a large amount of data and hogs the I/O of the system. The data is too big to fit into memory all at once, hence the use of a generator. And I have a consumer that uses all of the CPU to process the data yielded by the generator. It does not consume many other resources. Is it possible to interleave these tasks using threads?

For example, I'd guess it should be possible to run the simplified code below in about 11 seconds: the ten 1-second generation steps can overlap with consumption, leaving just one extra second to consume the final item.

import time, threading
lock = threading.Lock()
def gen():
    for x in range(10):
        time.sleep(1)      # simulated I/O-bound production, 1 s per item
        yield x
def con(x):
    lock.acquire()
    time.sleep(1)          # simulated CPU-bound consumption, 1 s per item
    lock.release()
    return x+1
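
For reference, running the two steps strictly one after the other takes about 20 seconds (ten 1-second generation steps plus ten 1-second consumption steps), which is the baseline any interleaved version has to beat. A minimal timing sketch, assuming the gen and con definitions above:

import time

t0 = time.time()
results = [con(x) for x in gen()]   # no overlap: each item is produced, then consumed
print(time.time() - t0)             # expected to be close to 20 s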

However, the simplest application of threads does not run in that time. It does speed things up, but I assume that is because of parallelism between the dispatcher, which does the generation, and the workers, not because of parallelism between the workers themselves.

import joblib
%time joblib.Parallel(n_jobs=2,backend='threading',pre_dispatch=2)((joblib.delayed(con)(x) for x in gen()))
# CPU times: user 0 ns, sys: 0 ns, total: 0 ns
# Wall time: 16 s
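
The same measurement without the %time magic, in case you want to reproduce it outside IPython (a sketch reusing gen and con from above):

import time, joblib

t0 = time.perf_counter()
results = joblib.Parallel(n_jobs=2, backend='threading', pre_dispatch=2)(
    joblib.delayed(con)(x) for x in gen()
)
print(f'{time.perf_counter() - t0:.1f} s')   # about 16 s observed, not the hoped-for 11 s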

Solution

  • Send your data to separate worker threads (a process-based variant is sketched further down). I used concurrent.futures because I like the simple interface.

    This runs in about 11 seconds on my computer.

    from concurrent.futures import ThreadPoolExecutor
    import concurrent
    import threading
    import time
    
    lock = threading.Lock()
    
    def gen():
        for x in range(10):
            time.sleep(1)
            yield x
    
    def con(x):
        lock.acquire()
        time.sleep(1)
        lock.release()
        return f'{x+1}'
    
    if __name__ == "__main__":
    
        futures = []
        with ThreadPoolExecutor() as executor:
            t0 = time.time()
            for x in gen():
                futures.append(executor.submit(con, x))
    
        # leaving the with-block waits for every submitted task to finish
        results = []
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())
        print(time.time() - t0)
        print('\n'.join(results))
    

    Using 100 generator iterations (for x in range(100): in gen) it took about 102 seconds, close to the roughly 101 seconds you would expect when the 1-second generation steps overlap with the lock-serialized 1-second consumption steps.
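
    If the real consumer is CPU-bound Python code rather than a GIL-releasing sleep, the same structure can also send the data to separate processes; a rough sketch using ProcessPoolExecutor, with the lock dropped because a threading.Lock is not shared across processes:

    from concurrent.futures import ProcessPoolExecutor, as_completed
    import time
    
    def gen():
        for x in range(10):
            time.sleep(1)              # I/O-bound production stays in the main process
            yield x
    
    def con(x):
        time.sleep(1)                  # stand-in for CPU-bound work done in a worker process
        return f'{x+1}'
    
    if __name__ == "__main__":
        t0 = time.time()
        futures = []
        with ProcessPoolExecutor() as executor:
            for x in gen():
                futures.append(executor.submit(con, x))
        results = [f.result() for f in as_completed(futures)]
        print(time.time() - t0)
        print('\n'.join(results))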


    Your process may need to keep track of how much data has been sent to tasks that haven't finished to prevent swamping memory resources.
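
    One simple way to do that is a semaphore which the submitting loop acquires before pulling the next item from the generator and which each finished task releases; a sketch assuming the gen and con definitions above (max_in_flight is an illustrative parameter, not something from the code above):

    from concurrent.futures import ThreadPoolExecutor, as_completed
    import threading
    
    def bounded_run(max_in_flight=4):
        # Never hold more than max_in_flight unfinished items at once, so the
        # generator's output cannot pile up in memory faster than it is consumed.
        slots = threading.BoundedSemaphore(max_in_flight)
        futures = []
        with ThreadPoolExecutor() as executor:
            for x in gen():
                slots.acquire()                           # block until a slot frees up
                fut = executor.submit(con, x)
                fut.add_done_callback(lambda f: slots.release())
                futures.append(fut)
        return [f.result() for f in as_completed(futures)]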

    Adding some diagnostic prints to con seems to show that there might be at least two chunks of data out there at a time.

    def con(x):
        # t0 here is the global start time set in the __main__ block above
        print(f'{x} received payload at t0 + {time.time()-t0:3.3f}')
        lock.acquire()
        time.sleep(1)
        lock.release()
        print(f'{x} released lock at t0 + {time.time()-t0:3.3f}')
        return f'{x+1}'