
How can I implement multithreading in this for loop?


Consider this code snippet:

from tqdm import trange


def main_game(depth1, depth2):
    # some operator with complexity O(20^max(depth1,depth2))
    return depth1+depth2


DEPTH_MAX = 5
total = 0
for depth1 in range(1, DEPTH_MAX + 1):
    for depth2 in range(1, DEPTH_MAX + 1):
        for i in trange(100):
            total += main_game(depth1, depth2)

print(total)

I'm using the minimax algorithm in main_game() with a branching factor of 10.

Now, since the third for loop calls a time-consuming function (up to 100 × O(20^5) in time complexity), is there any way to make it run faster? I'm thinking of parallelizing it (with multithreading, for example). Any suggestions?
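A rough estimate of where the time goes may help decide how to split the work. This sketch treats the cost comment in main_game() as exact. It assumes, hypothetically, that one call costs 20^max(depth1, depth2) units of work, and it sums that cost over all 25 depth pairs and 100 repetitions.

```python
# Cost model (assumption): one call to main_game(d1, d2) costs
# 20 ** max(d1, d2) units of work, as the comment in the question suggests.
DEPTH_MAX = 5
N_REPEATS = 100  # the inner `for i in trange(100)` loop

# work per (depth1, depth2) pair, including all repetitions
costs = {(d1, d2): N_REPEATS * 20 ** max(d1, d2)
         for d1 in range(1, DEPTH_MAX + 1)
         for d2 in range(1, DEPTH_MAX + 1)}

total_work = sum(costs.values())
# work contributed by the pairs in which at least one depth is DEPTH_MAX
deepest = sum(c for (d1, d2), c in costs.items() if max(d1, d2) == DEPTH_MAX)

print(f"total work units: {total_work:,}")
print(f"share of pairs with max depth {DEPTH_MAX}: {deepest / total_work:.1%}")
```

Under this model, about 96% of the work sits in the 9 pairs where one depth is 5. If you parallelize only over the 25 (depth1, depth2) pairs, each of those pairs becomes a single task. That caps the speedup at roughly 9× no matter how many cores you have. Splitting down to individual repetitions instead gives hundreds of equal-sized tasks that can keep every core busy.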


Solution

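  • Use multiprocessing rather than threads. In CPython, the global interpreter lock (GIL) lets only one thread execute Python bytecode at a time, so threads won't speed up CPU-bound work such as a minimax search, while separate processes will. Create a multiprocessing.Pool and call its starmap() method. starmap() unpacks each prepared tuple of arguments into a call of your function and runs those calls in parallel across the worker processes. It blocks until all results are collected and returns them in the same order as the input. The asynchronous variant .starmap_async() returns immediately with an AsyncResult object. Calling .get() on it blocks and returns the same ordered list, so use it when you want to do something else while the workers run.

    There are also Pool.apply() and Pool.map(), each with an _async() variant, but Pool.starmap() is all you need here. The differences are mostly syntactic: apply() makes a single call, map() passes one argument per call, and starmap() unpacks a tuple into several arguments.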

    import multiprocessing as mp
    
    # main_game is a dyadic function (it takes two arguments)
    def main_game(depth1, depth2):
        return depth1 + depth2
    
    DEPTH_MAX = 5
    N_REPEATS = 100  # the inner `for i in trange(100)` loop of the question
    
    # the guard is required where workers are spawned (Windows, macOS)
    if __name__ == "__main__":
        n_cpu = mp.cpu_count()
        depths = range(1, DEPTH_MAX + 1)
    
        # prepare all argument tuples up front, because that is fast
        depth1_depth2_pairs = [(d1, d2)
                               for d1 in depths
                               for d2 in depths
                               for _ in range(N_REPEATS)]
    
        # 1: create a pool of worker processes
        pool = mp.Pool(n_cpu)
        # 2: pool.starmap() calls main_game(d1, d2) for every pair in parallel
        results = pool.starmap(main_game, depth1_depth2_pairs)
        ## alternatively, starmap_async() returns at once and .get() blocks;
        ## the results come back in input order either way:
        ## results = pool.starmap_async(main_game, depth1_depth2_pairs).get()
        # 3: stop accepting new tasks and wait for the workers to exit
        pool.close()
        pool.join()
    
        total = sum(results)  # this replaces your `total +=`
        print(total)
    

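    You can write all of this more concisely with a with statement. It shuts the pool down automatically when the block exits, even if an error occurs, so it not only saves typing but is also safer. Note that leaving the with block calls terminate() rather than close(), so make sure all results have been retrieved (for example with .get()) inside the block.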

    import multiprocessing as mp
    
    def main_game(depth1, depth2):
        return depth1 + depth2
    
    DEPTH_MAX = 5
    N_REPEATS = 100
    
    if __name__ == "__main__":
        depths = range(1, DEPTH_MAX + 1)
        depth1_depth2_pairs = [(d1, d2)
                               for d1 in depths
                               for d2 in depths
                               for _ in range(N_REPEATS)]
    
        # leaving the with block terminates the pool,
        # so collect the results inside it
        with mp.Pool(mp.cpu_count()) as pool:
            results = pool.starmap_async(main_game, depth1_depth2_pairs).get()
    
        total = sum(results)
        print(total)