I've been reviewing the Depends docs; here is the official example:
from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()


async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
    return {"q": q, "skip": skip, "limit": limit}


@app.get("/items/")
async def read_items(commons: Annotated[dict, Depends(common_parameters)]):
    return commons
However, in my use case I need to serve an ML model that is updated on a recurring cadence (hourly, daily, etc.). The solution from the docs (above) relies on a callable dependency; as I understand it, the result is cached rather than regenerated each time. Either way, my use case is not scaffolding that needs to be set up and torn down on every invocation. Rather, I need a custom class with state: the ML model (a class attribute) is updated on a schedule and/or asynchronously, and the /invocations endpoint serves that model, reflecting updates as they occur.
Currently I use global variables. This works well while my entire application fits in a single script, but as the application grows I want to split it across routers, and I'm concerned that relying on global state will cause failures.
Is there an appropriate way to pass a stateful instance of a class object across methods?
See the example class and endpoint below:
class StateManager:
    def __init__(self):
        self.bucket = os.environ.get("BUCKET_NAME", "artifacts_bucket")
        self.s3_model_path = "./model.joblib"
        self.local_model_path = './model.joblib'

    def get_clients(self):
        self.s3 = boto3.client('s3')

    def download_model(self):
        self.s3.download_file(self.bucket, self.s3_model_path, self.local_model_path)
        self.model = joblib.load(self.local_model_path)

...
state = StateManager()
state.download_model()
...
@app.post("/invocations")
def invocations(request: InferenceRequest):
input_data = pd.DataFrame(dict(request), index=[0])
try:
predictions = state.model.predict(input_data)
return JSONResponse({"predictions": predictions.tolist()},
status_code=status.HTTP_200_OK)
except Exception as e:
return JSONResponse({"error": str(e)},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
Is there an appropriate way to pass a stateful instance of a class object across methods?
Yes, you could use lifespan for this:
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, TypedDict

import boto3
import joblib
from fastapi import Depends, FastAPI, Request


class State(TypedDict):
    model: Any


class StateManager:
    def __init__(self):
        self.bucket = os.environ.get("BUCKET_NAME", "artifacts_bucket")
        self.s3_model_path = "./model.joblib"
        self.local_model_path = './model.joblib'

    def get_clients(self):
        self.s3 = boto3.client('s3')

    def download_model(self):
        self.s3.download_file(self.bucket, self.s3_model_path, self.local_model_path)
        self.model = joblib.load(self.local_model_path)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
    # Runs once at startup: load the model before the app starts serving requests
    state_manager = StateManager()
    state_manager.get_clients()
    state_manager.download_model()
    state: State = {
        "model": state_manager.model
    }
    # Everything yielded here becomes available on request.state during requests
    yield state


app = FastAPI(lifespan=lifespan)
After this, you can add a dependency function that retrieves the model:
from fastapi import Request


def get_model(request: Request):
    # The dict yielded from lifespan is exposed on request.state
    return request.state.model
And finally, use it in your endpoint:
@app.post("/invocations")
async def invocations(request: InferenceRequest, model = Depends(get_model)):
...
With this solution, StateManager is initialised only once, and the model is shared across all routers through the lifespan state.
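For example, a router defined in a separate module can reuse the same dependency. This is a minimal sketch; the module names predictions, dependencies and schemas are placeholders for however your project is organised, and InferenceRequest is the request schema from your question:
# predictions.py (placeholder module name)
import pandas as pd
from fastapi import APIRouter, Depends

from dependencies import get_model      # wherever get_model lives in your project
from schemas import InferenceRequest    # your existing request schema

router = APIRouter()


@router.post("/invocations")
async def invocations(request: InferenceRequest, model=Depends(get_model)):
    # The model comes from the lifespan state, not from a global variable
    input_data = pd.DataFrame(dict(request), index=[0])
    predictions = model.predict(input_data)
    return {"predictions": predictions.tolist()}
In main.py you would then register it with app.include_router(router).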
The next steps depend on how you want to update the state periodically.
You could add a periodic update method directly to your StateManager:
class StateManager:
    ...

    async def update_model_periodically(self, interval: int = 3600):
        while True:
            # Note: download_model() does blocking I/O; if downloads are slow,
            # consider await asyncio.to_thread(self.download_model) instead
            self.download_model()
            print("Model updated...")
            await asyncio.sleep(interval)
And start this periodic task in the lifespan:
@asynccontextmanager
async def lifespan(app: FastAPI):
    ...
    # Keep a reference to the task so it is not garbage collected, and cancel it on shutdown
    update_task = asyncio.create_task(state_manager.update_model_periodically(interval=3600))
    yield state
    update_task.cancel()
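One caveat: the state dict yielded from lifespan holds a reference to the model object that was loaded at startup, so when download_model() later rebinds self.model, request.state.model still points at the old object. A minimal sketch of one way to make updates visible is to expose the manager itself in the lifespan state and read .model through it on each request (adjust the State TypedDict accordingly if you keep the annotation):
@asynccontextmanager
async def lifespan(app: FastAPI):
    state_manager = StateManager()
    state_manager.get_clients()
    state_manager.download_model()
    update_task = asyncio.create_task(state_manager.update_model_periodically(interval=3600))
    # The manager is shared by reference, so .model always reflects the latest download
    yield {"state_manager": state_manager}
    update_task.cancel()


def get_model(request: Request):
    return request.state.state_manager.model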