Tags: python, python-asyncio, fastapi, python-multiprocessing

Locking a resource in FastAPI - using a multiprocessing worker


I would like to make a FastAPI service with one GET endpoint which will return an ML-model inference result. It is pretty easy to implement that, but the catch is that I periodically need to update the model with a newer version (through a request to another server that hosts the models, but that is beside the point), and here I see a problem!

What will happen if one request calls the old model while it is being replaced by a newer one? How can I implement this kind of locking mechanism with asyncio?

Here is the code:

import asyncio
import time
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer

app = FastAPI()
sbertmodel = None


def create_model():
    global sbertmodel
    sbertmodel = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')


# if you try to run all predicts concurrently, it will result in CPU thrashing.
pool = ProcessPoolExecutor(max_workers=1, initializer=create_model)


def model_predict():
    ts = time.time()
    vector = sbertmodel.encode('How big is London')
    return vector


async def vector_search(vector):
    # simulate I/O call (e.g. Vector Similarity Search using a VectorDB)
    await asyncio.sleep(0.005)


@app.get("/")
async def entrypoint(request: Request):
    loop = asyncio.get_event_loop()
    ts = time.time()
    # worker should be initialized outside endpoint to avoid cold start
    vector = await loop.run_in_executor(pool, model_predict)
    print(f"Model  : {int((time.time() - ts) * 1000)}ms")
    ts = time.time()
    await vector_search(vector)
    print(f"io task: {int((time.time() - ts) * 1000)}ms")
    return "ok"

My model update would be implemented through repeated tasks (but that is not important now): https://fastapi-utils.davidmontague.xyz/user-guide/repeated-tasks/
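
Roughly, I imagine the repeated task looking like this (the polling interval and the fetch_new_model_version helper are just placeholders):

from fastapi_utils.tasks import repeat_every

@app.on_event("startup")
@repeat_every(seconds=60)  # placeholder polling interval
async def check_for_model_update() -> None:
    # fetch_new_model_version() is a hypothetical helper that asks the
    # model server whether a newer version exists and downloads it
    await fetch_new_model_version()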

This is the idea behind the model serving setup: https://luis-sena.medium.com/how-to-optimize-fastapi-for-ml-model-serving-6f75fb9e040d

EDIT: What is important is to run multiple requests concurrently and, while the model is updating, to acquire a lock so that requests don't fail - they should just wait a little longer, since it is a small model.
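
To illustrate the behaviour I'm after (just a sketch of the intent - replace_model is a hypothetical function, and the model actually lives in the worker process, which is exactly the part I don't know how to handle):

model_lock = asyncio.Lock()

@app.get("/")
async def entrypoint(request: Request):
    loop = asyncio.get_event_loop()
    async with model_lock:  # requests wait here while an update holds the lock
        vector = await loop.run_in_executor(pool, model_predict)
    await vector_search(vector)  # I/O still overlaps across requests
    return "ok"

async def replace_model():
    async with model_lock:  # no request is mid-predict while we swap
        ...  # somehow load the new model into the worker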


Solution

  • Thanks for your snippet. With it visible, it is possible to write a proposal for what you need - as it turns out, you need to update the model in the worker subprocess, and there is nothing to worry about in the main-process async part of the code. Signaling the worker processes about the updates, though, needs some attention.

    Since you are using ProcessPoolExecutor workers, you need a way to expose variables from the root process that the worker processes can "see".

    Python has this in the form of multiprocessing.Manager objects.
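
    For instance, here is a minimal, self-contained sketch of the mechanism (all names are invented for the illustration):

    from concurrent.futures import ProcessPoolExecutor
    from multiprocessing import Manager

    ns = None

    def init_worker(shared_ns):
        # each worker keeps a proxy to the same managed namespace
        global ns
        ns = shared_ns

    def read_counter():
        return ns.counter

    if __name__ == "__main__":
        with Manager() as manager:
            shared = manager.Namespace()
            shared.counter = 1
            with ProcessPoolExecutor(max_workers=1,
                                     initializer=init_worker,
                                     initargs=(shared,)) as pool:
                print(pool.submit(read_counter).result())  # -> 1
                shared.counter = 2  # an update from the root process...
                print(pool.submit(read_counter).result())  # -> 2 (...is visible)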

    Below, I take your code and add the parts needed for your requirement of "not immediate, but not conflicting" updating of the model in use. As it turns out, once we have variables that can be seen in the worker, all that is needed is a check in the model-runner method itself to see whether the model needs to be updated.

    I didn't run this snippet as a whole - so there might be a typo here and there - use it as a model, not "copy + paste" (but I did test the moving parts: Manager.Namespace() objects and passing them as initargs to a ProcessPoolExecutor).

    import asyncio
    import time
    from concurrent.futures import ProcessPoolExecutor
    from contextlib import asynccontextmanager
    from multiprocessing import Manager
    
    
    from fastapi import FastAPI, Request
    from sentence_transformers import SentenceTransformer
    
    sbertmodel = None
    local_model_iteration = -1
    shared_namespace = None
    
    # The pool and other multiprocessing objects can't simply be
    # created at the top level of the module, or they'd be re-created
    # in each subprocess!
    # check https://fastapi.tiangolo.com/advanced/events/#lifespan
    
    
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        global pool, root_namespace
        manager = Manager()
    
        root_namespace = manager.Namespace()
        
        # Values assigned to the "namespace" object are
        # visible in the subprocesses created by the pool
        
        root_namespace.model_iteration = 0
        root_namespace.model_parameters = "multi-qa-MiniLM-L6-cos-v1"
        
        # (as long as we send the namespace object to each subprocess
        # and store it there)
        pool = ProcessPoolExecutor(max_workers=1, initializer=initialize_subprocess, initargs=(root_namespace,))
        with pool, manager:
            # pass control to FastAPI: the whole app executes here
            yield
        # end of the "with" block:
        # both the pool and the manager are shut down when the FastAPI server exits!
        
    
    app = FastAPI(lifespan=lifespan)
    
    # (the pool keeps max_workers=1: if you try to run all predicts
    # concurrently, it will result in CPU thrashing)
    
    
    def initialize_subprocess(shared_namespace_arg):
        global shared_namespace
        # Store the shared namespace in _this_ process:
        shared_namespace = shared_namespace_arg
        update_model()
        
    def update_model():
        "called on worker subprocess start, and at any time the model is outdated" 
        global local_model_iteration, sbertmodel
        local_model_iteration = shared_namespace.model_iteration
        # retrieve parameter posted by root process:
        sbertmodel = SentenceTransformer(shared_namespace.model_parameters)
    
    
    
    def model_predict():
        ts = time.time()
        # check whether the root process has published a newer model
        if shared_namespace.model_iteration > local_model_iteration:
            # if so, just update the model
            update_model()
        # model is synchronized, just do our job:
        vector = sbertmodel.encode('How big is London')
        return vector
    
    
    async def vector_search(vector):
        # simulate I/O call (e.g. Vector Similarity Search using a VectorDB)
        await asyncio.sleep(0.005)
    
    
    @app.get("/")
    async def entrypoint(request: Request):
        loop = asyncio.get_event_loop()
        ts = time.time()
        # worker should be initialized outside endpoint to avoid cold start
        vector = await loop.run_in_executor(pool, model_predict)
        print(f"Model  : {int((time.time() - ts) * 1000)}ms")
        ts = time.time()
        await vector_search(vector)
        print(f"io task: {int((time.time() - ts) * 1000)}ms")
        return "ok"
    
    @app.get("/update_model")
    async def update_model_endpoint(request: Request):
        # extract the parameters needed for the new model from the request
        ...
        new_model_parameters = ...
        # update the model parameters and the model iteration so they are visible
        # in the worker(s)
        root_namespace.model_parameters = new_model_parameters
        # This increment taking place _after_ the "model_parameters" are set 
        # is all that is needed to keep things running in order here:
        root_namespace.model_iteration += 1
        return {} # whatever response needed by the endpoint
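
    A quick way to exercise the flow, assuming the snippet above is saved as main.py (and after filling in real parameters in update_model_endpoint, which are elided above):

    from fastapi.testclient import TestClient

    from main import app

    # entering the client as a context manager runs the lifespan handler,
    # which creates the Manager and the worker pool
    with TestClient(app) as client:
        client.get("/")              # predict with the initial model
        client.get("/update_model")  # bump model_iteration in the namespace
        client.get("/")              # worker notices, reloads, then predicts

    Note that because model_iteration is incremented only after model_parameters is written, a worker that observes the new iteration number is guaranteed to read the new parameters as well.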