Python multiprocessing is slower than regular processing. How can I improve it?


Basically, I have a script that combs through a dataset of nodes/points to remove those that overlap. The actual script is more complicated, but for demonstration I pared it down to a simple overlap check that does nothing with the result.

I tried a few variants with locks, queues, and pools, adding one job at a time versus adding them in bulk. Some of the worst offenders were slower by a couple of orders of magnitude. Eventually I got it as fast as I could.

The overlap-checking algorithm, which is sent to the individual processes:

def check_overlap(args):
    tolerance = args['tolerance']
    this_coords = args['this_coords']
    that_coords = args['that_coords']

    overlaps = False
    # abs() so the cheap early-out test also rejects far-away points on the negative side
    distance_x = abs(this_coords[0] - that_coords[0])
    if distance_x <= tolerance:
        distance_x = pow(distance_x, 2)
        distance_y = abs(this_coords[1] - that_coords[1])
        if distance_y <= tolerance:
            distance = pow(distance_x + pow(distance_y, 2), 0.5)
            if distance <= tolerance:
                overlaps = True

    return overlaps

The processing function:

def process_coords(coords, num_processors=1, tolerance=1):
    import multiprocessing as mp
    import time

    if num_processors > 1:
        pool = mp.Pool(num_processors)
        start = time.time()
        print("Start script w/ multiprocessing")

    else:
        num_processors = 0
        start = time.time()
        print("Start script w/ standard processing")

    total_overlap_count = 0

    # outer loop through nodes
    start_index = 0
    last_index = len(coords) - 1
    while start_index <= last_index:

        # nature of the original problem means we can process all pairs of a single node at once, but not multiple, so batch jobs by outer loop
        batch_jobs = []

        # inner loop against every later node (the node is not compared with itself)
        count_overlapping = 0
        for i in range(start_index + 1, last_index + 1):

            if num_processors:
                # add job
                batch_jobs.append({
                    'tolerance': tolerance,
                    'this_coords': coords[start_index],
                    'that_coords': coords[i]
                })

            else:
                # synchronous processing
                this_coords = coords[start_index]
                that_coords = coords[i]
                distance_x = abs(this_coords[0] - that_coords[0])
                if distance_x <= tolerance:
                    distance_x = pow(distance_x, 2)
                    distance_y = abs(this_coords[1] - that_coords[1])
                    if distance_y <= tolerance:
                        distance = pow(distance_x + pow(distance_y, 2), 0.5)
                        if distance <= tolerance:
                            count_overlapping += 1

        if num_processors:
            res = pool.map_async(check_overlap, batch_jobs)
            results = res.get()
            for r in results:
                if r:
                    count_overlapping += 1

        # stuff normally happens here to process nodes connected to this node
        total_overlap_count += count_overlapping
        start_index += 1

    print(total_overlap_count)
    print("  time: {0}".format(time.time() - start))

    if num_processors:
        # shut the worker processes down once all batches are done
        pool.close()
        pool.join()

And the testing function:

if __name__ == "__main__":
    # the guard is required wherever worker processes re-import this module (spawn start method)
    from random import random

    coords = []
    num_coords = 1000
    spread = 100.0
    half_spread = 0.5*spread
    for i in range(num_coords):
        coords.append([
            random()*spread-half_spread,
            random()*spread-half_spread
        ])

    process_coords(coords, 1)
    process_coords(coords, 4)

Still, the non-multiprocessing version consistently runs in under 0.4 s, while the best I can get the multiprocessing version to, as it stands above, is just under 3.0 s. I get that maybe the algorithm here is too simple to really reap the benefits, but considering that the above case has nearly half a million iterations and the real case has significantly more, it's weird to me that the multiprocessing version is an order of magnitude slower.
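The "nearly half a million iterations" is the number of unordered pairs, N(N-1)/2, and in the multiprocessing path each pair becomes a dict that must be pickled and shipped to a worker. A rough, illustrative comparison using the standard pickle module follows; the coordinate values are made up, and multiprocessing uses its own pickle subclass and batches tasks into chunks, so the sizes are approximate rather than exact bytes on the wire.

```python
import pickle

# One job shaped like the question's batch_jobs entries (values made up for illustration)
job = {'tolerance': 1, 'this_coords': [12.5, -3.25], 'that_coords': [40.0, 7.75]}

num_coords = 1000
pairs = num_coords * (num_coords - 1) // 2  # one dict job per unordered pair

dict_bytes = len(pickle.dumps(job))   # approximate cost of sending one dict job
index_bytes = len(pickle.dumps(17))   # approximate cost of sending one int index

print("pairs:", pairs)
print("one dict job:", dict_bytes, "bytes; one index job:", index_bytes, "bytes")
print("all dict jobs:", pairs * dict_bytes, "bytes")
print("one index per node:", num_coords * index_bytes, "bytes")
```

The exact byte counts depend on the interpreter and pickle protocol; the point is the gap between O(N**2) dict jobs and O(N) small integer jobs.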

What am I missing / what can I do to improve?


Solution

  • Building O(N**2) 3-element dicts that the serial code never uses, and transmitting them over interprocess pipes, is a pretty good way to guarantee multiprocessing can't help ;-) Nothing comes for free; everything costs.

    Below is a rewrite that executes much the same code regardless of whether it's run in serial or multiprocessing mode. No new dicts, etc. In general, the larger len(coords), the more benefit it gets from multiprocessing. On my box, with 20000 points, the multiprocessing run takes about a third of the wall-clock time of the serial run.

    Key to this is that every process has its own copy of coords. Below, this is done by transmitting it just once, when the pool is created (through the pool's initializer). That should work on all platforms. On Linux-y systems it could instead happen "by magic", through inheritance by forked processes. Reducing the amount of data sent across processes from O(N**2) to O(N) is a huge improvement.

    Getting more out of multiprocessing would require better load balancing. As is, a call to check_overlap(i) compares coords[i] to each value in coords[i+1:]. The larger i, the less work there is for it to do, and for the largest values of i the cost of transmitting i between processes (and transmitting the result back) swamps the time spent in check_overlap(i).

    def init(*args):
        global _coords, _tolerance
        _coords, _tolerance = args
    
    def check_overlap(start_index):
        coords, tolerance = _coords, _tolerance
        tsq = tolerance ** 2
        overlaps = 0
        start0, start1 = coords[start_index]
        for i in range(start_index + 1, len(coords)):
            that0, that1 = coords[i]
            dx = abs(that0 - start0)
            if dx <= tolerance:
                dy = abs(that1 - start1)
                if dy <= tolerance:
                    if dx**2 + dy**2 <= tsq:
                        overlaps += 1
        return overlaps
    
    def process_coords(coords, num_processors=1, tolerance=1):
        global _coords, _tolerance
        import multiprocessing as mp
        _coords, _tolerance = coords, tolerance
        import time
    
        if num_processors > 1:
            pool = mp.Pool(num_processors, initializer=init, initargs=(coords, tolerance))
            start = time.time()
            print("Start script w/ multiprocessing")
        else:
            num_processors = 0
            start = time.time()
            print("Start script w/ standard processing")
    
        N = len(coords)
        if num_processors:
            total_overlap_count = sum(pool.imap_unordered(check_overlap, range(N)))
        else:
            total_overlap_count = sum(check_overlap(i) for i in range(N))
    
        print(total_overlap_count)
        print("  time: {0}".format(time.time() - start))
    
        if num_processors:
            # shut the worker processes down cleanly
            pool.close()
            pool.join()
    
    if __name__ == "__main__":
        from random import random
    
        coords = []
        num_coords = 20000
        spread = 100.0
        half_spread = 0.5*spread
        for i in range(num_coords):
            coords.append([
                random()*spread-half_spread,
                random()*spread-half_spread
            ])
    
        process_coords(coords, 1)
        process_coords(coords, 4)
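For the fork route the answer mentions (workers inheriting coords "by magic"), here is a minimal sketch: set a module-level global in the parent before the pool starts and explicitly request the fork start method, so the workers see the data without any initializer/initargs. It assumes a POSIX system; fork is unavailable on Windows and is not the default start method on every platform or Python version (macOS has defaulted to spawn since Python 3.8), which is why get_context("fork") is requested explicitly. count_points and run_demo are hypothetical names for illustration.

```python
import multiprocessing as mp

_coords = None  # module-level global read by the workers

def count_points(_):
    # A forked child sees the parent's _coords as it was when the pool started
    return len(_coords)

def run_demo():
    global _coords
    _coords = [[0.0, 0.0], [0.5, 0.5], [3.0, 3.0]]
    ctx = mp.get_context("fork")  # raises ValueError where fork is unavailable
    with ctx.Pool(2) as pool:     # note: no initializer/initargs
        return pool.map(count_points, range(4))

if __name__ == "__main__":
    print(run_demo())  # [3, 3, 3, 3]
```

The initializer approach used in the answer remains the portable option; the fork version only avoids the one-time transfer on platforms that support it.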
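On the load-balancing point, a sketch of one possible approach (not taken from the answer, and with no timings claimed) is to pair each expensive early row i with the cheap late row N-1-i, so every task does about N-1 comparisons, and to pass a chunksize to imap_unordered so several small tasks travel in one message. count_row repeats the answer's check_overlap logic; count_pair_of_rows, balanced_total, serial_total and the chunksize of 16 are hypothetical names and an untuned value chosen for illustration.

```python
import multiprocessing as mp
from random import random

_coords, _tolerance = None, None  # set in each process by init()

def init(coords, tolerance):
    global _coords, _tolerance
    _coords, _tolerance = coords, tolerance

def count_row(start_index):
    # Same work as the answer's check_overlap: coords[start_index] vs coords[start_index+1:]
    coords, tolerance = _coords, _tolerance
    tsq = tolerance ** 2
    overlaps = 0
    x0, y0 = coords[start_index]
    for i in range(start_index + 1, len(coords)):
        x1, y1 = coords[i]
        dx = abs(x1 - x0)
        if dx <= tolerance:
            dy = abs(y1 - y0)
            if dy <= tolerance and dx * dx + dy * dy <= tsq:
                overlaps += 1
    return overlaps

def count_pair_of_rows(i):
    # Hypothetical helper: row i costs n-1-i comparisons and row n-1-i costs i,
    # so each task does about n-1 comparisons whichever i it gets
    n = len(_coords)
    j = n - 1 - i
    total = count_row(i)
    if j != i:  # the middle row of an odd-length list is its own partner
        total += count_row(j)
    return total

def balanced_total(coords, tolerance=1, num_processors=4):
    n = len(coords)
    with mp.Pool(num_processors, initializer=init, initargs=(coords, tolerance)) as pool:
        # i in 0 .. ceil(n/2)-1 together with n-1-i covers every row exactly once;
        # chunksize=16 is an untuned guess that groups several tasks per message
        tasks = range((n + 1) // 2)
        return sum(pool.imap_unordered(count_pair_of_rows, tasks, chunksize=16))

def serial_total(coords, tolerance=1):
    init(coords, tolerance)  # sets the globals in this process
    return sum(count_row(i) for i in range(len(coords)))

if __name__ == "__main__":
    coords = [[random() * 100 - 50, random() * 100 - 50] for _ in range(2000)]
    print(serial_total(coords), balanced_total(coords))
```

Both totals should match; whether the pairing actually beats the answer's version on a given machine would need measuring.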