Search code examples
Tags: python, multiprocessing, python-multiprocessing, tqdm, process-pool

Starmap combined with tqdm?


I am doing some parallel processing, as follows:

# Run my_function over all argument tuples in `inputs` on a pool of 8 workers;
# starmap() hands back the complete result list in one go.
with mp.Pool(8) as tmpPool:
        results = tmpPool.starmap(my_function, inputs)

where inputs looks like [(1, 0.2312), (5, 0.52), ...], i.e., a list of tuples, each holding an int and a float.

The code runs nicely, yet I cannot seem to wrap it in a progress bar (tqdm), as can be done with, e.g., the imap method:

# imap() is a method of a Pool instance (the multiprocessing module has no
# top-level imap()), so it is called on the pool object created above.
tqdm.tqdm(tmpPool.imap(some_function, some_inputs))

Can this be done for starmap also?

Thanks!


Solution

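  • It's not possible with starmap(), because starmap() only returns once the complete result list is ready, so there is nothing for tqdm to iterate over while the work is running. It is possible with a patch that adds Pool.istarmap(), which is based on the code for imap(). All you have to do is create the file istarmap.py and import that module, to apply the patch, before you make your regular multiprocessing imports.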

    Python <3.8

    # istarmap.py for Python <3.8
    import multiprocessing.pool as mpp
    
    
    def istarmap(self, func, iterable, chunksize=1):
        """Lazy, iterator-returning variant of ``starmap()``, modelled on ``imap()``.

        ``iterable`` is cut into batches of ``chunksize`` elements by
        ``Pool._get_tasks`` and the batches are queued with
        ``mpp.starmapstar`` as the worker-side callable.

        Returns a generator that flattens the result chunks delivered by the
        ``IMapIterator`` into individual items.

        Raises ``ValueError`` if the pool is not in the ``RUN`` state or if
        ``chunksize`` is smaller than 1.
        """
        if self._state != mpp.RUN:
            raise ValueError("Pool not running")
        if chunksize < 1:
            raise ValueError("Chunksize must be 1+, not {0:n}".format(chunksize))

        batches = mpp.Pool._get_tasks(func, iterable, chunksize)
        # On Python <3.8 the iterator is constructed from the pool's cache.
        pending = mpp.IMapIterator(self._cache)
        task_source = self._guarded_task_generation(
            pending._job, mpp.starmapstar, batches)
        self._taskqueue.put((task_source, pending._set_length))

        # Every value produced by ``pending`` is one batch of results.
        return (item for chunk in pending for item in chunk)


    # Monkeypatch: make istarmap available as a method on every Pool.
    mpp.Pool.istarmap = istarmap
    

    Python 3.8+

    # istarmap.py for Python 3.8+
    import multiprocessing.pool as mpp
    
    
    def istarmap(self, func, iterable, chunksize=1):
        """Lazy, iterator-returning variant of ``starmap()``, modelled on ``imap()``.

        The running-state check is delegated to the pool's own
        ``_check_running()``. ``iterable`` is then cut into batches of
        ``chunksize`` elements by ``Pool._get_tasks`` and the batches are
        queued with ``mpp.starmapstar`` as the worker-side callable.

        Returns a generator that flattens the result chunks delivered by the
        ``IMapIterator`` into individual items.

        Raises ``ValueError`` if ``chunksize`` is smaller than 1.
        """
        self._check_running()
        if chunksize < 1:
            raise ValueError("Chunksize must be 1+, not {0:n}".format(chunksize))

        batches = mpp.Pool._get_tasks(func, iterable, chunksize)
        # On Python 3.8+ the iterator is constructed from the pool itself.
        pending = mpp.IMapIterator(self)
        task_source = self._guarded_task_generation(
            pending._job, mpp.starmapstar, batches)
        self._taskqueue.put((task_source, pending._set_length))

        # Every value produced by ``pending`` is one batch of results.
        return (item for chunk in pending for item in chunk)


    # Monkeypatch: make istarmap available as a method on every Pool.
    mpp.Pool.istarmap = istarmap
    

    Then in your script:

    import istarmap  # import to apply patch
    from multiprocessing import Pool
    import tqdm    
    
    
    def foo(a, b):
        """Spin through an empty loop, then return both arguments as a tuple."""
        # 50 million no-op iterations; presumably a stand-in for real
        # CPU-bound work so that each task takes a noticeable amount of time.
        for _ in range(int(50e6)):
            pass
        return a, b
    
    
    if __name__ == '__main__':
        # One argument tuple per task; each tuple is unpacked into foo(a, b).
        jobs = [(i, 'x') for i in range(10)]

        with Pool(4) as pool:
            # istarmap() returns an iterator, so tqdm can advance as items
            # are consumed; the iterator has no length, so the total is
            # passed explicitly.
            progress = tqdm.tqdm(pool.istarmap(foo, jobs), total=len(jobs))
            for _ in progress:
                pass