
FastAPI stateful dependencies


I've been reviewing the Depends docs, in particular this official example:

from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()


async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
    return {"q": q, "skip": skip, "limit": limit}


@app.get("/items/")
async def read_items(commons: Annotated[dict, Depends(common_parameters)]):
    return commons

However, my use case is serving an ML model that is updated on a recurring cadence (hourly, daily, etc.). The approach from the docs (above) relies on a dependency function; as far as I can tell, FastAPI calls that function again for every request, caching its result only within a single request. My model, however, is not scaffolding that should go up and come down with each invocation; I need a custom class that holds state. The idea is that the ML model (an instance attribute) can be updated on a schedule and/or asynchronously, and the /invocations endpoint will serve that model, reflecting updates as they occur.

Currently, I use global variables. This works well while the entire application fits in a single script. However, as the application grows, I'd like to split it up with APIRouter, and I'm concerned that global state will cause failures.

Is there an appropriate way to pass a stateful instance of a class object across methods?

Here is an example class and endpoint:

import os

import boto3
import joblib
import pandas as pd
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse

class StateManager:
    def __init__(self):
        self.bucket = os.environ.get("BUCKET_NAME", "artifacts_bucket")
        self.s3_model_path = "./model.joblib"
        self.local_model_path = './model.joblib'

    def get_clients(self):
        self.s3 = boto3.client('s3')

    def download_model(self):
        self.s3.download_file(self.bucket, self.s3_model_path, self.local_model_path)
        self.model = joblib.load(self.local_model_path)

...

# Module-level global, created once when the script is imported
state = StateManager()
state.get_clients()  # download_model() needs self.s3, which is created here
state.download_model()

...


@app.post("/invocations")
def invocations(request: InferenceRequest):
    input_data = pd.DataFrame(dict(request), index=[0])
    try:
        predictions = state.model.predict(input_data)
        return JSONResponse({"predictions": predictions.tolist()},
                            status_code=status.HTTP_200_OK)
    except Exception as e:
        return JSONResponse({"error": str(e)},
                            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)


Solution

  • Is there an appropriate way to pass a stateful instance of a class object across methods?

    Yes, you can use a lifespan handler for this:

    import asyncio
    import os
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, TypedDict

    import boto3
    import joblib
    from fastapi import Depends, FastAPI, Request

    class StateManager:
        def __init__(self):
            self.bucket = os.environ.get("BUCKET_NAME", "artifacts_bucket")
            self.s3_model_path = "./model.joblib"
            self.local_model_path = "./model.joblib"

        def get_clients(self):
            self.s3 = boto3.client("s3")

        def download_model(self):
            self.s3.download_file(self.bucket, self.s3_model_path, self.local_model_path)
            self.model = joblib.load(self.local_model_path)

    class State(TypedDict):
        # Hold the manager, not the model: each request gets a shallow copy of this
        # dict, so a model swapped in later is only visible through the manager
        state_manager: StateManager

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[State]:
        # Runs once at startup, before any request is served
        state_manager = StateManager()
        state_manager.get_clients()
        state_manager.download_model()

        state: State = {"state_manager": state_manager}
        yield state

    app = FastAPI(lifespan=lifespan)

    After this, you can add a dependency that returns the current model:

    from fastapi import Request


    def get_model(request: Request):
        # Lifespan state is exposed on request.state (not request.app.state)
        return request.state.state_manager.model

    And finally, use it in your endpoint. Because get_model only needs the request, it can live in a shared module and be used from any APIRouter without a global variable:

    @app.post("/invocations")
    def invocations(request: InferenceRequest, model=Depends(get_model)):
        # A plain def lets FastAPI run the blocking predict() in its threadpool
        ...
    

    With this approach, StateManager is initialised only once, at startup, and every endpoint, in any router, gets the model through the get_model dependency.

    The next steps depend on how you want to update the state periodically.

    One option is to add the periodic update directly to StateManager:

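    class StateManager:
        ...
        async def update_model_periodically(self, interval: int = 3600):
            while True:
                # The model was already loaded at startup, so wait first
                await asyncio.sleep(interval)
                try:
                    # boto3 and joblib are blocking; run them in a worker thread
                    await asyncio.to_thread(self.download_model)
                    print("Model updated...")
                except Exception as e:
                    # Keep serving the previous model if a refresh fails
                    print(f"Model update failed: {e}")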
    
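Whether download_model() runs on the event loop or in a worker thread matters, because boto3 and joblib are synchronous. The sketch below measures this with a hypothetical blocking_download() (a 0.3 s time.sleep standing in for the S3 fetch) and a heartbeat task standing in for other requests the server should keep answering; asyncio.to_thread requires Python 3.9+.

```python
import asyncio
import time


def blocking_download() -> None:
    # Stand-in for boto3 download_file + joblib.load: blocks its thread for 0.3 s
    time.sleep(0.3)


async def heartbeat(ticks: list) -> None:
    # Stands in for other work the event loop should keep serving
    while True:
        ticks.append(time.monotonic())
        await asyncio.sleep(0.01)


async def ticks_during_download(use_thread: bool) -> int:
    ticks: list = []
    hb = asyncio.create_task(heartbeat(ticks))
    await asyncio.sleep(0)  # let the heartbeat start
    start = len(ticks)
    if use_thread:
        await asyncio.to_thread(blocking_download)
    else:
        blocking_download()  # runs on the event loop thread and freezes it
    during = len(ticks) - start
    hb.cancel()
    return during


blocked = asyncio.run(ticks_during_download(use_thread=False))
threaded = asyncio.run(ticks_during_download(use_thread=True))
print(blocked, threaded)
```

Called directly, the download lets no heartbeat tick through; with asyncio.to_thread the loop keeps ticking roughly every 10 ms while the download runs.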

    Then start it as a background task in the lifespan:

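    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[State]:
        ...  # create state_manager and state as in the lifespan above
        # Keep a reference so the task isn't garbage-collected mid-run
        task = asyncio.create_task(state_manager.update_model_periodically(interval=3600))
        yield state
        task.cancel()  # stop refreshing on shutdown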
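The refresh loop and its shutdown can be checked without S3 or a server. This sketch assumes a hypothetical StateManager whose download_model() sleeps briefly and bumps a version string, and uses a 0.05 s interval instead of an hour. Catching Exception inside the loop does not swallow task.cancel(), because asyncio.CancelledError derives from BaseException (Python 3.8+).

```python
import asyncio
import itertools
import time


class StateManager:
    """Hypothetical stand-in: download_model() simulates a slow, blocking S3 fetch."""

    def __init__(self):
        self._versions = itertools.count(1)
        self.model = None

    def download_model(self):
        time.sleep(0.01)  # placeholder for boto3 download_file + joblib.load
        self.model = f"model-v{next(self._versions)}"

    async def update_model_periodically(self, interval: float = 3600):
        while True:
            await asyncio.sleep(interval)
            try:
                await asyncio.to_thread(self.download_model)
            except Exception as e:
                print(f"Model update failed: {e}")


async def main() -> tuple:
    manager = StateManager()
    manager.download_model()  # startup load, as in the lifespan
    first = manager.model
    task = asyncio.create_task(manager.update_model_periodically(interval=0.05))

    # Poll (for at most ~5 s) until the background task swaps in a new model
    for _ in range(500):
        if manager.model != first:
            break
        await asyncio.sleep(0.01)
    refreshed = manager.model

    task.cancel()  # what the lifespan does on shutdown
    try:
        await task
    except asyncio.CancelledError:
        pass
    return first, refreshed, task.cancelled()


first, refreshed, cancelled = asyncio.run(main())
print(first, refreshed, cancelled)
```

Because download_model() rebinds self.model only after the new model is fully loaded, concurrent requests see either the old model or the new one, never a half-loaded object.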