Search code examples
pythonasynchronousmultiprocessingbreadth-first-searchconcurrent.futures

ThreadPoolExecutor multiprocessing with while loop and breadth first search?


I'm trying to speed up some API calls by using ThreadPoolExecutor. I have a class that accepts a string list of h3 cells like cell1,cell2. h3 uses hexagons at different resolutions to get finer detail in mapping. The class methods take the returned cells and gets information about them that is passed to an API with params. The API will return a total number of results (could be over 1000). Because the API is limited to returning at most the first 1000 results through pagination, I utilize h3 to zoom into each cell until all of its children/grandchildren/etc have a total number of results under 1000. This is effectively doing BFS from the original cells provided.

When running this code with the run method, the expectation is that the search_queue would be empty as all cells have been processed. However, with the way its set up currently, only the origin_cells provided to the class get processed and retrieving search_queue shows unprocessed items. Swapping the while and ThreadPoolExecutor lines does run everything as expected, but it runs at the same speed as without using ThreadPoolExecutor.

Is there a way to make the multiprocessing work as expected?

Edit with working example

import h3
import math
import requests
from concurrent.futures import ThreadPoolExecutor
from time import sleep

dummy_results = {
    '85489e37fffffff': {'total': 1001},
    '85489e27fffffff': {'total': 999},
    '86489e347ffffff': {'total': 143},
    '86489e34fffffff': {'total': 143},
    '86489e357ffffff': {'total': 143},
    '86489e35fffffff': {'total': 143},
    '86489e367ffffff': {'total': 143},
    '86489e36fffffff': {'total': 143},
    '86489e377ffffff': {'total': 143},
}

class SearchH3Test(object):

    def __init__(self, origin_cells):
        self.search_queue = list(filter(None, origin_cells.split(',')))
        self.params_list = []
        
    def get_h3_radius(self, cell, buffer=False):
        """
        Get the approximate radius of the h3 cell
        """
        return math.ceil(
            math.sqrt(
                (h3.cell_area(cell))/(1.5*math.sqrt(3))
            )*1000
            + ((100*(h3.h3_get_resolution(cell)/10)) if buffer else 0)
        )
    
    def get_items(self, cell):
        """
        Return API items from passed params, including total number of items and a dict of items

        r = requests.get(
            url = 'https://someapi.com',
            headers = api_headers,
            params = params
        ).json()
        """
        sleep(1)

        r = dummy_results[cell]

        return r['total']
    
    def get_hex_params(self, cell):
        """
        Return results from the derived params of the h3 cell
        """
        lat, long = h3.h3_to_geo(cell)
        radius = self.get_h3_radius(cell, buffer=True)

        params = {
            'latitude': lat,
            'longitude': long,
            'radius': radius,
        }

        total = self.get_items(cell)
        print(total)

        return total, params
    
    def hex_search(self):
        """
        Checks if the popped h3 cell produces a total value over 1000.
        If over 1000, get the h3 cell children and append them to the search_queue
        If greater than 0, append params to params_list
        """
        cell = self.search_queue.pop(0)
        total, params = self.get_hex_params(cell)
        if total > 1000:
            self.search_queue.extend(list(h3.h3_to_children(cell)))
        elif total > 0:
            self.params_list.append(params)
    
    def get_params_list(self):
        """
        Keep looping through the search quque until no items remain.
        Use multiprocessing to speed up things
        """
        with ThreadPoolExecutor() as e:
            while self.search_queue:
                e.submit(self.hex_search)
    
    def run(self):
        self.get_params_list()
h = SearchH3Test(
    '85489e37fffffff,85489e27fffffff',
)

h.run()
len(h.search_queue) # returns 7 for the children that weren't processed as expected
len(h.params_list) # returns 1 for the cell under 1000

Solution

  • When dealing with multiple threads/process, you must use a queue data structure, which is thread safe. Here is my rewrite of your code:

    import math
    import queue
    from concurrent.futures import ThreadPoolExecutor
    
    import h3
    
    dummy_results = {
        "85489e37fffffff": {"total": 1001},
        "85489e27fffffff": {"total": 999},
        "86489e347ffffff": {"total": 143},
        "86489e34fffffff": {"total": 143},
        "86489e357ffffff": {"total": 143},
        "86489e35fffffff": {"total": 143},
        "86489e367ffffff": {"total": 143},
        "86489e36fffffff": {"total": 143},
        "86489e377ffffff": {"total": 143},
    }
    
    
    class SearchH3Test:
        def __init__(self, origin_cells):
            self.search_queue = queue.Queue()
            for cell in origin_cells.split(","):
                if cell:
                    self.search_queue.put(cell)
            self.params_list = queue.Queue()
    
        def get_h3_radius(self, cell, buffer=False):
            """
            Get the approximate radius of the h3 cell
            """
            return math.ceil(
                math.sqrt((h3.cell_area(cell)) / (1.5 * math.sqrt(3))) * 1000
                + ((100 * (h3.h3_get_resolution(cell) / 10)) if buffer else 0)
            )
    
        def get_items(self, cell):
            """
            Return API items from passed params, including total number of items and a dict of items
    
            r = requests.get(
                url = 'https://someapi.com',
                headers = api_headers,
                params = params
            ).json()
            """
            result = dummy_results[cell]
            return result["total"]
    
        def get_hex_params(self, cell):
            """
            Return results from the derived params of the h3 cell
            """
            lat, long = h3.h3_to_geo(cell)
            radius = self.get_h3_radius(cell, buffer=True)
    
            params = {
                "latitude": lat,
                "longitude": long,
                "radius": radius,
            }
    
            total = self.get_items(cell)
            print(f"{total=}")
    
            return total, params
    
        def hex_search(self):
            """
            Checks if the popped h3 cell produces a total value over 1000.
            If over 1000, get the h3 cell children and append them to the search_queue
            If greater than 0, append params to params_list
            """
            try:
                cell = self.search_queue.get_nowait()
            except queue.Empty:
                return
    
            total, params = self.get_hex_params(cell)
            if total > 1000:
                for children in h3.h3_to_children(cell):
                    self.search_queue.put(children)
            elif total > 0:
                self.params_list.put(params)
    
        def get_params_list(self):
            """
            Keep looping through the search quque until no items remain.
            Use multiprocessing to speed up things
            """
            with ThreadPoolExecutor() as executor:
                while not self.search_queue.empty():
                    executor.submit(self.hex_search)
    
        def run(self):
            self.get_params_list()
    
    
        def debug_print(self):
            print()
            print("search_queue:")
            temp_queue = queue.Queue()
            while not self.search_queue.empty():
                element = self.search_queue.get()
                print(f"  {element}")
                temp_queue.put(element)
            self.search_queue = temp_queue
    
            temp_queue = queue.Queue()
            print("params_list:")
            while not self.params_list.empty():
                param = self.params_list.get()
                print(f"  {param}")
                temp_queue.put(param)
            self.params_list = temp_queue
            print()
    
    
    h = SearchH3Test(
        "85489e37fffffff,85489e27fffffff",
    )
    h.debug_print()
    h.run()
    h.debug_print()
    

    Here is the output:

    search_queue:
      85489e37fffffff
      85489e27fffffff
    params_list:
    
    total=1001
    total=999
    total=143
    total=143
    total=143
    total=143
    total=143
    total=143
    total=143
    
    search_queue:
    params_list:
      {'latitude': 30.439669178763303, 'longitude': -97.74145162032264, 'radius': 10589}
      {'latitude': 30.274024109195224, 'longitude': -97.83825545496954, 'radius': 4047}
      {'latitude': 30.32926630624169, 'longitude': -97.8060126711982, 'radius': 4046}
      {'latitude': 30.331766923236522, 'longitude': -97.73454628471268, 'radius': 4045}
      {'latitude': 30.223829096285723, 'longitude': -97.72765390431441, 'radius': 4046}
      {'latitude': 30.221311831797678, 'longitude': -97.79905214479454, 'radius': 4047}
      {'latitude': 30.276552903546015, 'longitude': -97.76681177354345, 'radius': 4046}
      {'latitude': 30.27904199229733, 'longitude': -97.69539091275956, 'radius': 4045}
    

    Notes

    • I use queue.Queue in place of lists for the reason stated above
    • Most of the changes in the code focus on the differences between a queue and a list. For example list.pop() -> queue.get() or queue.get_nowait()
    • I added the debug_print() method to aid debugging. You can remove it if no longer needed

    Reference

    Update

    In this update, I am still using queue.Queue() to store data, but shift the Queue.get() operation from hex_search to get_params_list().

    In get_params_list(), I will retrieve a cell from the self.search_queue, but this time with a time out of 5 seconds. I also allow time out to occur 3 times before declaring a done.

    Other changes includes adding some random delay to get_items() to simulate upstream delay. I also created a LOGGER object for debugging purpose. A logger is nicer than print: When all is done, I can just set the level to logging.WARN to turn off all the messages. Also, a logger works better with multi processing environment.

    import logging
    import math
    import queue
    import random
    import time
    from concurrent.futures import ThreadPoolExecutor
    
    import h3
    
    # Create a logger
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(levelname)s: %(funcName)s: %(message)s"
    )
    LOGGER = logging.getLogger()
    
    dummy_results = {
        "85489e37fffffff": {"total": 1001},
        "85489e27fffffff": {"total": 999},
        "86489e347ffffff": {"total": 143},
        "86489e34fffffff": {"total": 143},
        "86489e357ffffff": {"total": 143},
        "86489e35fffffff": {"total": 143},
        "86489e367ffffff": {"total": 143},
        "86489e36fffffff": {"total": 143},
        "86489e377ffffff": {"total": 143},
    }
    
    
    class SearchH3Test:
        def __init__(self, origin_cells):
            self.search_queue = queue.Queue()
            for cell in origin_cells.split(","):
                if cell:
                    self.search_queue.put(cell)
            self.params_list = queue.Queue()
    
        def get_h3_radius(self, cell, buffer=False):
            """
            Get the approximate radius of the h3 cell
            """
            return math.ceil(
                math.sqrt((h3.cell_area(cell)) / (1.5 * math.sqrt(3))) * 1000
                + ((100 * (h3.h3_get_resolution(cell) / 10)) if buffer else 0)
            )
    
        def get_items(self, cell):
            """
            Return API items from passed params, including total number of items and a dict of items
    
            r = requests.get(
                url = 'https://someapi.com',
                headers = api_headers,
                params = params
            ).json()
            """
            # Simulate some random delay
            delay = random.randint(1, 10)
            LOGGER.debug("Delay for %d second(s)", delay)
            time.sleep(delay)
            result = dummy_results[cell]
            return result["total"]
    
        def get_hex_params(self, cell):
            """
            Return results from the derived params of the h3 cell
            """
            lat, long = h3.h3_to_geo(cell)
            radius = self.get_h3_radius(cell, buffer=True)
    
            params = {
                "latitude": lat,
                "longitude": long,
                "radius": radius,
            }
    
            total = self.get_items(cell)
            LOGGER.debug("total=%d", total)
    
            return total, params
    
        def hex_search(self, cell):
            """
            Checks if the popped h3 cell produces a total value over 1000.
            If over 1000, get the h3 cell children and append them to the search_queue
            If greater than 0, append params to params_list
            """
            total, params = self.get_hex_params(cell)
            if total > 1000:
                for children in h3.h3_to_children(cell):
                    self.search_queue.put(children)
            elif total > 0:
                self.params_list.put(params)
    
        def get_params_list(self):
            """
            Keep looping through the search quque until no items remain.
            Use multiprocessing to speed up things
            """
            with ThreadPoolExecutor() as executor:
                timeout_count = 0
                while timeout_count < 3:
                    try:
                        cell = self.search_queue.get(timeout=5)
                        executor.submit(self.hex_search, cell)
                    except queue.Empty:
                        timeout_count += 1
    
        def run(self):
            self.get_params_list()
    
    
        def debug_print(self):
            LOGGER.debug("search_queue:")
            temp_queue = queue.Queue()
            while not self.search_queue.empty():
                element = self.search_queue.get()
                LOGGER.debug("  %s", element)
                temp_queue.put(element)
            LOGGER.debug("  <END>")
            self.search_queue = temp_queue
    
            temp_queue = queue.Queue()
            LOGGER.debug("params_list:")
            while not self.params_list.empty():
                param = self.params_list.get()
                LOGGER.debug("  %s", param)
                temp_queue.put(param)
            LOGGER.debug("  <END>")
            self.params_list = temp_queue
    
    
    h = SearchH3Test(
        "85489e37fffffff,85489e27fffffff",
    )
    h.debug_print()
    h.run()
    h.debug_print()