I'm trying to speed up some API calls by using ThreadPoolExecutor
. I have a class that accepts a string list of h3 cells like cell1,cell2
. h3 uses hexagons at different resolutions to get finer detail in mapping. The class methods take the returned cells and gets information about them that is passed to an API with params. The API will return a total number of results (could be over 1000). Because the API is limited to returning at most the first 1000 results through pagination, I utilize h3 to zoom into each cell until all of its children/grandchildren/etc have a total number of results under 1000. This is effectively doing BFS from the original cells provided.
When running this code with the run
method, the expectation is that the search_queue
would be empty as all cells have been processed. However, with the way its set up currently, only the origin_cells
provided to the class get processed and retrieving search_queue
shows unprocessed items. Swapping the while
and ThreadPoolExecutor
lines does run everything as expected, but it runs at the same speed as without using ThreadPoolExecutor
.
Is there a way to make the multiprocessing work as expected?
Edit with working example
import h3
import math
import requests
from concurrent.futures import ThreadPoolExecutor
from time import sleep
dummy_results = {
'85489e37fffffff': {'total': 1001},
'85489e27fffffff': {'total': 999},
'86489e347ffffff': {'total': 143},
'86489e34fffffff': {'total': 143},
'86489e357ffffff': {'total': 143},
'86489e35fffffff': {'total': 143},
'86489e367ffffff': {'total': 143},
'86489e36fffffff': {'total': 143},
'86489e377ffffff': {'total': 143},
}
class SearchH3Test(object):
def __init__(self, origin_cells):
self.search_queue = list(filter(None, origin_cells.split(',')))
self.params_list = []
def get_h3_radius(self, cell, buffer=False):
"""
Get the approximate radius of the h3 cell
"""
return math.ceil(
math.sqrt(
(h3.cell_area(cell))/(1.5*math.sqrt(3))
)*1000
+ ((100*(h3.h3_get_resolution(cell)/10)) if buffer else 0)
)
def get_items(self, cell):
"""
Return API items from passed params, including total number of items and a dict of items
r = requests.get(
url = 'https://someapi.com',
headers = api_headers,
params = params
).json()
"""
sleep(1)
r = dummy_results[cell]
return r['total']
def get_hex_params(self, cell):
"""
Return results from the derived params of the h3 cell
"""
lat, long = h3.h3_to_geo(cell)
radius = self.get_h3_radius(cell, buffer=True)
params = {
'latitude': lat,
'longitude': long,
'radius': radius,
}
total = self.get_items(cell)
print(total)
return total, params
def hex_search(self):
"""
Checks if the popped h3 cell produces a total value over 1000.
If over 1000, get the h3 cell children and append them to the search_queue
If greater than 0, append params to params_list
"""
cell = self.search_queue.pop(0)
total, params = self.get_hex_params(cell)
if total > 1000:
self.search_queue.extend(list(h3.h3_to_children(cell)))
elif total > 0:
self.params_list.append(params)
def get_params_list(self):
"""
Keep looping through the search quque until no items remain.
Use multiprocessing to speed up things
"""
with ThreadPoolExecutor() as e:
while self.search_queue:
e.submit(self.hex_search)
def run(self):
self.get_params_list()
h = SearchH3Test(
'85489e37fffffff,85489e27fffffff',
)
h.run()
len(h.search_queue) # returns 7 for the children that weren't processed as expected
len(h.params_list) # returns 1 for the cell under 1000
When dealing with multiple threads/process, you must use a queue data structure, which is thread safe. Here is my rewrite of your code:
import math
import queue
from concurrent.futures import ThreadPoolExecutor
import h3
dummy_results = {
"85489e37fffffff": {"total": 1001},
"85489e27fffffff": {"total": 999},
"86489e347ffffff": {"total": 143},
"86489e34fffffff": {"total": 143},
"86489e357ffffff": {"total": 143},
"86489e35fffffff": {"total": 143},
"86489e367ffffff": {"total": 143},
"86489e36fffffff": {"total": 143},
"86489e377ffffff": {"total": 143},
}
class SearchH3Test:
def __init__(self, origin_cells):
self.search_queue = queue.Queue()
for cell in origin_cells.split(","):
if cell:
self.search_queue.put(cell)
self.params_list = queue.Queue()
def get_h3_radius(self, cell, buffer=False):
"""
Get the approximate radius of the h3 cell
"""
return math.ceil(
math.sqrt((h3.cell_area(cell)) / (1.5 * math.sqrt(3))) * 1000
+ ((100 * (h3.h3_get_resolution(cell) / 10)) if buffer else 0)
)
def get_items(self, cell):
"""
Return API items from passed params, including total number of items and a dict of items
r = requests.get(
url = 'https://someapi.com',
headers = api_headers,
params = params
).json()
"""
result = dummy_results[cell]
return result["total"]
def get_hex_params(self, cell):
"""
Return results from the derived params of the h3 cell
"""
lat, long = h3.h3_to_geo(cell)
radius = self.get_h3_radius(cell, buffer=True)
params = {
"latitude": lat,
"longitude": long,
"radius": radius,
}
total = self.get_items(cell)
print(f"{total=}")
return total, params
def hex_search(self):
"""
Checks if the popped h3 cell produces a total value over 1000.
If over 1000, get the h3 cell children and append them to the search_queue
If greater than 0, append params to params_list
"""
try:
cell = self.search_queue.get_nowait()
except queue.Empty:
return
total, params = self.get_hex_params(cell)
if total > 1000:
for children in h3.h3_to_children(cell):
self.search_queue.put(children)
elif total > 0:
self.params_list.put(params)
def get_params_list(self):
"""
Keep looping through the search quque until no items remain.
Use multiprocessing to speed up things
"""
with ThreadPoolExecutor() as executor:
while not self.search_queue.empty():
executor.submit(self.hex_search)
def run(self):
self.get_params_list()
def debug_print(self):
print()
print("search_queue:")
temp_queue = queue.Queue()
while not self.search_queue.empty():
element = self.search_queue.get()
print(f" {element}")
temp_queue.put(element)
self.search_queue = temp_queue
temp_queue = queue.Queue()
print("params_list:")
while not self.params_list.empty():
param = self.params_list.get()
print(f" {param}")
temp_queue.put(param)
self.params_list = temp_queue
print()
h = SearchH3Test(
"85489e37fffffff,85489e27fffffff",
)
h.debug_print()
h.run()
h.debug_print()
Here is the output:
search_queue:
85489e37fffffff
85489e27fffffff
params_list:
total=1001
total=999
total=143
total=143
total=143
total=143
total=143
total=143
total=143
search_queue:
params_list:
{'latitude': 30.439669178763303, 'longitude': -97.74145162032264, 'radius': 10589}
{'latitude': 30.274024109195224, 'longitude': -97.83825545496954, 'radius': 4047}
{'latitude': 30.32926630624169, 'longitude': -97.8060126711982, 'radius': 4046}
{'latitude': 30.331766923236522, 'longitude': -97.73454628471268, 'radius': 4045}
{'latitude': 30.223829096285723, 'longitude': -97.72765390431441, 'radius': 4046}
{'latitude': 30.221311831797678, 'longitude': -97.79905214479454, 'radius': 4047}
{'latitude': 30.276552903546015, 'longitude': -97.76681177354345, 'radius': 4046}
{'latitude': 30.27904199229733, 'longitude': -97.69539091275956, 'radius': 4045}
queue.Queue
in place of lists for the reason stated abovelist.pop()
-> queue.get()
or queue.get_nowait()
debug_print()
method to aid debugging. You can remove it if no longer neededIn this update, I am still using queue.Queue()
to store data, but shift the Queue.get()
operation from hex_search
to get_params_list()
.
In get_params_list()
, I will retrieve a cell from the self.search_queue
, but this time with a time out of 5 seconds. I also allow time out to occur 3 times before declaring a done.
Other changes includes adding some random delay to get_items()
to simulate upstream delay. I also created a LOGGER
object for debugging purpose. A logger is nicer than print: When all is done, I can just set the level to logging.WARN
to turn off all the messages. Also, a logger works better with multi processing environment.
import logging
import math
import queue
import random
import time
from concurrent.futures import ThreadPoolExecutor
import h3
# Create a logger
logging.basicConfig(
level=logging.DEBUG,
format="%(levelname)s: %(funcName)s: %(message)s"
)
LOGGER = logging.getLogger()
dummy_results = {
"85489e37fffffff": {"total": 1001},
"85489e27fffffff": {"total": 999},
"86489e347ffffff": {"total": 143},
"86489e34fffffff": {"total": 143},
"86489e357ffffff": {"total": 143},
"86489e35fffffff": {"total": 143},
"86489e367ffffff": {"total": 143},
"86489e36fffffff": {"total": 143},
"86489e377ffffff": {"total": 143},
}
class SearchH3Test:
def __init__(self, origin_cells):
self.search_queue = queue.Queue()
for cell in origin_cells.split(","):
if cell:
self.search_queue.put(cell)
self.params_list = queue.Queue()
def get_h3_radius(self, cell, buffer=False):
"""
Get the approximate radius of the h3 cell
"""
return math.ceil(
math.sqrt((h3.cell_area(cell)) / (1.5 * math.sqrt(3))) * 1000
+ ((100 * (h3.h3_get_resolution(cell) / 10)) if buffer else 0)
)
def get_items(self, cell):
"""
Return API items from passed params, including total number of items and a dict of items
r = requests.get(
url = 'https://someapi.com',
headers = api_headers,
params = params
).json()
"""
# Simulate some random delay
delay = random.randint(1, 10)
LOGGER.debug("Delay for %d second(s)", delay)
time.sleep(delay)
result = dummy_results[cell]
return result["total"]
def get_hex_params(self, cell):
"""
Return results from the derived params of the h3 cell
"""
lat, long = h3.h3_to_geo(cell)
radius = self.get_h3_radius(cell, buffer=True)
params = {
"latitude": lat,
"longitude": long,
"radius": radius,
}
total = self.get_items(cell)
LOGGER.debug("total=%d", total)
return total, params
def hex_search(self, cell):
"""
Checks if the popped h3 cell produces a total value over 1000.
If over 1000, get the h3 cell children and append them to the search_queue
If greater than 0, append params to params_list
"""
total, params = self.get_hex_params(cell)
if total > 1000:
for children in h3.h3_to_children(cell):
self.search_queue.put(children)
elif total > 0:
self.params_list.put(params)
def get_params_list(self):
"""
Keep looping through the search quque until no items remain.
Use multiprocessing to speed up things
"""
with ThreadPoolExecutor() as executor:
timeout_count = 0
while timeout_count < 3:
try:
cell = self.search_queue.get(timeout=5)
executor.submit(self.hex_search, cell)
except queue.Empty:
timeout_count += 1
def run(self):
self.get_params_list()
def debug_print(self):
LOGGER.debug("search_queue:")
temp_queue = queue.Queue()
while not self.search_queue.empty():
element = self.search_queue.get()
LOGGER.debug(" %s", element)
temp_queue.put(element)
LOGGER.debug(" <END>")
self.search_queue = temp_queue
temp_queue = queue.Queue()
LOGGER.debug("params_list:")
while not self.params_list.empty():
param = self.params_list.get()
LOGGER.debug(" %s", param)
temp_queue.put(param)
LOGGER.debug(" <END>")
self.params_list = temp_queue
h = SearchH3Test(
"85489e37fffffff,85489e27fffffff",
)
h.debug_print()
h.run()
h.debug_print()