
Python multi-threading method


I've heard that Python multi-threading is a bit tricky, and I am not sure what the best way is to implement what I need. Let's say I have a function called IO_intensive_function that makes an API call which may take a while to return a response.

Say the process of queuing jobs can look something like this:

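import _thread  # Python 3 name of the old Python 2 "thread" module
for job_args in jobs:
    # the args must be a tuple, hence the trailing comma
    _thread.start_new_thread(IO_intensive_function, (job_args,))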

Would IO_intensive_function now just execute its task in the background and allow me to queue more jobs?
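For reference, `_thread.start_new_thread` does return immediately and runs the function in a new thread, but the low-level module gives you no way to wait for those threads, and on most systems they are killed when the main thread exits. A minimal sketch of the same idea with the higher-level `threading` module, assuming a hypothetical `IO_intensive_function` that just sleeps in place of the real API call and stores its result in a shared list:

```python
import threading
import time


def IO_intensive_function(job_args, results, index):
    # Hypothetical stand-in for the slow API call.
    time.sleep(0.1)
    results[index] = job_args * 2  # each thread writes only its own slot


jobs = [1, 2, 3]
results = [None] * len(jobs)
threads = []
for index, job_args in enumerate(jobs):
    t = threading.Thread(target=IO_intensive_function, args=(job_args, results, index))
    t.start()  # returns at once; the call runs in the background
    threads.append(t)

# The main thread is free to queue more work here.

for t in threads:
    t.join()  # wait for every job before reading the results

print(results)  # [2, 4, 6]
```

Unlike `_thread`, `join()` lets the main thread wait for the background work before it uses the results or exits.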

I also looked at this question, where the suggested approach seems to be the following:

from multiprocessing.dummy import Pool as ThreadPool
pool = ThreadPool(2)
results = pool.map(IO_intensive_function, jobs)

I don't need these tasks to communicate with each other; my only goal is to send my API requests as fast as possible. Is this the most efficient way? Thanks.
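Both approaches work for I/O-bound calls, because CPython releases the GIL while a thread is blocked on network I/O. In practice the pool size matters more than the module: `ThreadPool(2)` never has more than two requests in flight at once. A sketch with the standard-library `concurrent.futures.ThreadPoolExecutor`, where the sleeping `IO_intensive_function` is a hypothetical stand-in and `max_workers=20` is an assumption you would tune to what the API server tolerates:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def IO_intensive_function(job_args):
    # Hypothetical stand-in for the blocking API call.
    time.sleep(0.2)
    return job_args * 2


jobs = list(range(20))

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=20) as executor:
    # map() returns results in the same order as jobs
    results = list(executor.map(IO_intensive_function, jobs))
elapsed = time.perf_counter() - start

print(results[:3])  # [0, 2, 4]
print(f"{len(jobs)} jobs in {elapsed:.2f}s")
```

Run serially, the 20 calls would take about 4 seconds; with 20 workers they overlap, so the total is close to a single call's latency.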

Edit: The way I am making the API request is through a Thrift service.
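With Thrift, one detail matters: a generated client wraps a single open transport, and such clients are generally not safe to share between threads. One common pattern is to give each worker thread its own client, created lazily through `threading.local`. In this sketch `HypotheticalClient` is a placeholder for your generated client plus its transport; the real constructor and method names depend on your `.thrift` service definition:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

_local = threading.local()


class HypotheticalClient:
    """Placeholder for a generated Thrift client and its open transport."""

    def __init__(self):
        self.owner = threading.get_ident()  # thread that created this client

    def call(self, job_args):
        # Report which thread created the client and which thread is using it.
        return self.owner, threading.get_ident(), job_args


def get_client():
    # Create one client per worker thread on first use, then reuse it.
    client = getattr(_local, "client", None)
    if client is None:
        client = HypotheticalClient()  # real code: open transport, build client
        _local.client = client
    return client


def IO_intensive_function(job_args):
    return get_client().call(job_args)


with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(IO_intensive_function, range(12)))

print(len({owner for owner, _, _ in results}))  # at most 4 clients created
```

Each client is only ever used by the thread that created it, so no locking around the Thrift calls is needed.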


Solution

  • I recently had to write code that does something similar, and I've tried to make it generic below. Note that I'm a novice coder, so please forgive the inelegance. What you may find valuable, however, is the error handling I found necessary to catch disconnects and similar failures.

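    I also found it valuable to do the JSON processing inside the threads. The threads are already working for you, so why go "serial" again for a processing step when you can extract the information in parallel? (In CPython the GIL means the parsing itself does not run truly in parallel, but it does overlap with the other threads' network waits.)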

    It is possible I introduced mistakes while making it generic. Please don't hesitate to ask follow-up questions, and I will clarify.

    import time

    import requests
    from multiprocessing.dummy import Pool as ThreadPool
    from src_code.config import Config

    with open(Config.API_PATH + '/api_security_key.pem') as f:
        my_key = f.read().rstrip("\n")  # the with block closes the file itself
    base_url = "https://api.my_api_destination.com/v1"
    headers = {"Authorization": "Bearer %s" % my_key}
    itm = [base_url, headers]  # arguments shared by every call


    def call_API(call_var):
        base_url, headers, call_specific_tag = call_var

        endpoint = f'/api_path/{call_specific_tag}'

        # Retry up to 3 times, pausing after a failure (disconnect, bad JSON, ...).
        dat = None
        for attempt in range(1, 4):
            try:
                # timeout=60 is an added assumption; without one a dead connection can hang forever
                dat = requests.get(base_url + endpoint, headers=headers, timeout=60).json()
            except (requests.exceptions.RequestException, ValueError):
                print(f'Call for {call_specific_tag} failed after {attempt} attempt(s).  Pausing for 240 seconds.')
                time.sleep(240)
            else:
                break

        vars_to_capture_01 = list()
        vars_to_capture_02 = list()

        try:
            if 'record_id' in dat:
                # Read both fields first so a KeyError leaves the lists untouched.
                record_id = dat['record_id']
                second_item = dat['second_item_of_interest']
                vars_to_capture_01.append(record_id)
                vars_to_capture_02.append(second_item)
            else:
                vars_to_capture_01.append(call_specific_tag)
                print(f'Call specific tag {call_specific_tag} is unavailable.  Successful pull.')
                vars_to_capture_02.append(-1)

        except (TypeError, KeyError):
            # dat is None (every retry failed) or the response lacked a field
            print(f'{call_specific_tag} is unavailable.  Unsuccessful pull.')
            vars_to_capture_01.append(call_specific_tag)
            vars_to_capture_02.append(-1)
            time.sleep(240)

        pack = [vars_to_capture_01, vars_to_capture_02]
        return pack

    # all_tags: the list of call-specific tags to request (defined elsewhere)
    vars_to_capture_01 = list()
    vars_to_capture_02 = list()

    max_i = len(all_tags)
    for i in range(0, max_i, 10):  # batches of up to 10 concurrent calls
        batch_tags = all_tags[i:i + 10]
        call_var = [itm + [tag] for tag in batch_tags]  # itm + [...] builds a new list
        #packed = call_API(call_var[0]) # for testing of function without pooling
        pool = ThreadPool(len(call_var))
        packed = pool.map(call_API, call_var)
        pool.close()
        pool.join()
        for tag, pack in zip(batch_tags, packed):
            try:
                vars_to_capture_01.append(pack[0][0])
                vars_to_capture_02.append(pack[1][0])
            except IndexError:
                print(f'Unpacking error for {tag}.')
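The same ideas (retries around disconnects, parsing inside the worker) can also be written with `concurrent.futures` and a single pool for all tags, so a free worker starts the next tag right away instead of waiting for the slowest call in each batch of 10. This is a restructured sketch, not the code above: `fetch` is a hypothetical callable standing in for `requests.get(...).json()`, the `extract` helper mirrors the `record_id` / `second_item_of_interest` placeholders, and `max_workers=10` is an assumption:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fetch_with_retries(fetch, tag, attempts=3, pause=240.0):
    """Call fetch(tag), retrying on network or JSON errors; None if all fail."""
    for attempt in range(1, attempts + 1):
        try:
            return fetch(tag)
        except (OSError, ValueError) as exc:
            # Network errors derive from OSError; malformed JSON raises ValueError.
            print(f"Call for {tag} failed after {attempt} attempt(s): {exc!r}")
            if attempt < attempts:
                time.sleep(pause)
    return None


def extract(tag, data):
    """Parse one response: (record_id, second item) or a placeholder pair."""
    if isinstance(data, dict) and "record_id" in data:
        return data["record_id"], data.get("second_item_of_interest", -1)
    return tag, -1  # unavailable tag or failed pull


def process_all(tags, fetch, max_workers=10, pause=240.0):
    def worker(tag):
        # Fetch and parse in the same thread, overlapping other threads' waits.
        return extract(tag, fetch_with_retries(fetch, tag, pause=pause))

    # One pool for every tag: a free worker takes the next tag immediately.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(worker, tags))  # keeps input order
    record_ids = [r[0] for r in results]
    second_items = [r[1] for r in results]
    return record_ids, second_items


# Usage with a hypothetical fake fetch: "b" fails once, "c" has no record.
calls = {}


def fake_fetch(tag):
    calls[tag] = calls.get(tag, 0) + 1
    if tag == "b" and calls[tag] == 1:
        raise ConnectionError("simulated disconnect")
    if tag == "c":
        return {"error": "not found"}
    return {"record_id": tag.upper(), "second_item_of_interest": len(tag)}


result = process_all(["a", "b", "c"], fake_fetch, pause=0)
print(result)  # (['A', 'B', 'c'], [1, 1, -1])
```

With `requests`, `fetch` could be something like `lambda tag: requests.get(base_url + f'/api_path/{tag}', headers=headers, timeout=60).json()`. Requests' exception classes derive from `OSError` (`IOError`), so the `except` clause above also catches them.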