Tags: python, multiprocessing, fastapi, starlette

How to pass an unpickleable object from a startup event to a function used in a multiprocessing pool?


Here is my code:

import time
from multiprocessing import Pool
from functools import partial
from fastapi import FastAPI

clientObject = package.client()  # object is not pickleable
app = FastAPI()

def wrapper(func, retries, data_point):
    # wrapper function that adds a retry mechanism
    retry = 0
    while retry < retries:
        try:
            result = func(data_point)
        except Exception as err:
            result = err
            retry += 1  # count the failed attempt, otherwise this loops forever
            time.sleep(5)
        else:
            break
    return result

def get_response(data_point):
    # uses clientObject to get data from an Azure endpoint
    data_point = some_other_processes(data_point)
    ans = clientObject.process(data_point)
    return ans

def main(raw_data):
    # main function where I use the multiprocessing pool
    list_data_point = preprocess(raw_data)
    with Pool() as pool:
        wrapped_workload = partial(wrapper, 
                                   get_response,
                                   3)
        results = pool.map(wrapped_workload, list_data_point)
        pool.close()
        pool.join()
    return results


@app.post("/get_answer")
def get_answer(raw_data):
    processed_data = main(raw_data)
    return processed_data

The code above works fine if I declare clientObject at the beginning (as a global variable). But if I store it as an object in app.state like this:

@app.on_event("startup")
def start_connection():
    app.state.clientObject = package.client()

and access it inside get_response function like this:

def get_response(data_point):
    data_point = some_other_processes(data_point)
    ans = app.state.clientObject.process(data_point)
    return ans

it throws an error: 'State' object has no attribute 'clientObject'. However, app.state.clientObject is still available inside the main() function.

Also, since clientObject is not pickleable, I cannot pass it as an argument to a get_response(data_point, clientObject) function.

Is there any way to initialize the clientObject on startup, store it in a variable, and access it from a function used in the multiprocessing pool (without declaring it as a global)?

Edit: This is my solution, following the suggestion of Frank Yellin below:

def initialize_workers():
    # runs once in each worker process when the pool starts
    global clientObject
    clientObject = package.client()

def get_response(data_point):
    # reads the per-process clientObject created by the initializer
    data_point = some_other_processes(data_point)
    ans = clientObject.process(data_point)
    return ans

def main(raw_data):
    list_data_point = preprocess(raw_data)
    with Pool(initializer=initialize_workers) as pool:
        wrapped_workload = partial(wrapper, 
                                   get_response,
                                   3)
        results = pool.map(wrapped_workload, list_data_point)
        pool.close()
        pool.join()
    return results

Solution

  • Your code does not work because, as you mentioned, clientObject is not pickleable: it is not an object that you can simply copy from one process and use in another. Each process needs its own clientObject.
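To see why the object cannot cross process boundaries, here is a minimal sketch. The `Client` class below is a hypothetical stand-in for `package.client()`; it holds a thread lock, one common example of an OS-level resource that the pickle module refuses to serialize (network connections and open sockets behave the same way):

```python
import pickle
import threading

class Client:
    """Toy stand-in for package.client(): holds a lock, which cannot be pickled."""
    def __init__(self):
        self._lock = threading.Lock()

try:
    pickle.dumps(Client())
except TypeError as err:
    print(f"not pickleable: {err}")
```

Anything passed to `pool.map` (or captured by the mapped function's arguments) goes through the same `pickle.dumps` step, which is why such objects must be created inside each worker instead.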

    So you need to make sure that each process creates its own clientObject on startup. The way your code works now is certainly one way to do that.

    An alternative way is to use the initializer and initargs arguments when you create the Pool. These specify a function, and its arguments, that is called each time a new worker process starts. That function is a good place to create your clientObject and store it in a global or in your app.state.
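Here is a self-contained sketch of the initializer/initargs pattern. `FakeClient` and the endpoint URL are made up for illustration (the real `package.client()` may take different arguments); the pattern itself — a per-process global set by the initializer — is what the answer describes:

```python
from multiprocessing import Pool

client = None  # per-process global, set by the initializer in each worker

class FakeClient:
    """Hypothetical stand-in for package.client(); assume it takes an endpoint URL."""
    def __init__(self, endpoint):
        self.endpoint = endpoint

    def process(self, data_point):
        return f"{self.endpoint}:{data_point}"

def init_worker(endpoint):
    # runs once in every worker process when the pool starts
    global client
    client = FakeClient(endpoint)

def get_response(data_point):
    # each worker uses its own client instance; nothing unpickleable crosses processes
    return client.process(data_point)

if __name__ == "__main__":
    with Pool(initializer=init_worker,
              initargs=("https://example.azure.net",)) as pool:
        print(pool.map(get_response, [1, 2, 3]))
```

Only the `initargs` tuple is pickled and sent to the workers, so it must contain plain data (here, a URL string), never the client object itself.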