Tags: databricks, azure-databricks, mlflow

How to stream data from a Databricks model serving endpoint?


I’m developing an application built on a custom MLflow PyFunc model. The application has several components: query rephrasing, intent detection, chunk retrieval, and response generation. The whole process takes about 7-8 seconds. The first 3-4 seconds go to chunk retrieval and the remaining 3-4 seconds to generating the answer.
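The pipeline described above can be sketched in plain Python with only the final generation stage written as a generator, so tokens can be handed out as soon as retrieval finishes. All function names here (`rephrase_query`, `retrieve_chunks`, `generate_answer_tokens`, `answer_stream`) are hypothetical stand-ins for the real components, and the OpenAI streaming call is replaced by a fixed token list.

```python
from typing import Iterator, List


def rephrase_query(query: str) -> str:
    # Hypothetical stand-in for the query-rephrasing step.
    return query.strip()


def retrieve_chunks(query: str) -> List[str]:
    # Hypothetical stand-in for chunk retrieval (the slow, blocking step).
    return [f"chunk about {query}"]


def generate_answer_tokens(query: str, chunks: List[str]) -> Iterator[str]:
    # Hypothetical stand-in for a streaming LLM call; yields tokens one by one.
    for token in ["Answer", " based", " on", f" {len(chunks)}", " chunk(s)."]:
        yield token


def answer_stream(query: str) -> Iterator[str]:
    # Blocking stages run to completion before the first token is yielded.
    rephrased = rephrase_query(query)
    chunks = retrieve_chunks(rephrased)
    # Only the generation stage is streamed back to the caller.
    yield from generate_answer_tokens(rephrased, chunks)


print("".join(answer_stream(" refunds ")))
```

Because `answer_stream` is a generator function, nothing runs until the caller starts iterating; from then on, each token reaches the caller as soon as it is produced instead of after the whole answer is built.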

To improve the user experience, I want to start streaming the final response from OpenAI at around the 4-second mark, so the user doesn't have to wait the full 7-8 seconds. However, I can't return a generator from the regular predict method in MLflow. Doing so raises the following error:

mlflow.exceptions.MlflowException: Encountered an unexpected error while converting model response to JSON. Error: 'Object of type generator is not JSON serializable.'
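The underlying message can be reproduced with the standard `json` module alone: a generator object has no JSON representation, whereas the values it yields encode fine once collected into a list. This only illustrates the error text; it does not show MLflow's internal serialization code, and `tokens` and `try_serialize` are hypothetical helpers.

```python
import json


def tokens():
    # A generator, like the one returned from predict in the question.
    yield "a"
    yield "b"


def try_serialize(obj) -> str:
    """Return the JSON string, or the TypeError message if encoding fails."""
    try:
        return json.dumps(obj)
    except TypeError as exc:
        return str(exc)


print(try_serialize(tokens()))        # Object of type generator is not JSON serializable
print(try_serialize(list(tokens())))  # ["a", "b"]
```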

I found in the MLflow docs that predict_stream can return a generator. This works well when I load the model with load_model in a notebook, but after I update the model in the model serving endpoint, the endpoint returns None as the response.

Here is a custom-model example I found in the MLflow documentation:

import mlflow

# Define a custom model that supports streaming
class StreamableModel(mlflow.pyfunc.PythonModel):
    def predict_stream(self, context, model_input, params=None):
        # Yielding elements one at a time
        for element in ["a", "b", "c", "d", "e"]:
            yield element

Here's the code to log the model:

with mlflow.start_run():
    # Log the custom Python model and register it; streamable=True marks it as supporting predict_stream
    mlflow.pyfunc.log_model(
        "stream_model",
        python_model=StreamableModel(),
        registered_model_name="stream_model",
        streamable=True,
    )

Is there a way to return a generator (that is, stream the response) from a model serving endpoint?


Solution

  • According to the documentation, the return type of predict_stream should be Iterator[Union[Dict[str, Any], str]], that is, an iterator of dictionaries or strings.

    So returning a plain generator object does not work.
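To check a `predict_stream` implementation against the documented return type `Iterator[Union[Dict[str, Any], str]]` before logging it, one option is to wrap its output in a small validator. `check_stream_chunks` is a hypothetical helper, not part of MLflow; it re-yields each chunk and raises as soon as one is neither a dict nor a string.

```python
from typing import Any, Dict, Iterable, Iterator, Optional, Union

Chunk = Union[Dict[str, Any], str]


def check_stream_chunks(chunks: Iterable[Any]) -> Iterator[Chunk]:
    """Re-yield each chunk, raising TypeError on the first non-dict, non-str one."""
    for i, chunk in enumerate(chunks):
        if not isinstance(chunk, (dict, str)):
            raise TypeError(
                f"chunk {i} has type {type(chunk).__name__}; expected dict or str"
            )
        yield chunk


# Strings and dicts pass through unchanged.
ok = list(check_stream_chunks(["a", {"delta": "b"}]))

# Integers, for example, are rejected on the first chunk.
error: Optional[str] = None
try:
    list(check_stream_chunks(iter([1, 3, 2, 3])))
except TypeError as exc:
    error = str(exc)

print(ok)
print(error)
```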

    Even if you add a function like the one below:

    def predict_stream(self, context, model_input, params=None):
        return iter([1, 3, 2, 3])
    

    The scoring server still calls the predict function itself, not predict_stream.

    As the image below shows, the output is null.

    You can refer to this for more on the issue above.

    The community is working on adding this support. Until then, use the predict function.
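Until the serving endpoint calls `predict_stream`, one workaround consistent with the advice above is to keep the streaming logic in `predict_stream` and have `predict` drain it and return a JSON-serializable value. The sketch below uses a plain class so it runs without MLflow installed; in a real model the class would subclass `mlflow.pyfunc.PythonModel`. The class name is hypothetical, and joining string chunks into one string is an assumption about the desired output format. The client then receives the full answer in a single response, so this does not restore token-by-token streaming at the endpoint.

```python
from typing import Any, Dict, Iterator, Optional


class StreamableModelWithFallback:
    # Hypothetical name; in practice this would subclass mlflow.pyfunc.PythonModel.

    def predict_stream(
        self, context, model_input, params: Optional[Dict[str, Any]] = None
    ) -> Iterator[str]:
        # Yield string chunks, matching the documented return type.
        for element in ["a", "b", "c", "d", "e"]:
            yield element

    def predict(
        self, context, model_input, params: Optional[Dict[str, Any]] = None
    ) -> str:
        # Drain the stream and return a plain string, which is JSON serializable.
        return "".join(self.predict_stream(context, model_input, params))


model = StreamableModelWithFallback()
print(model.predict(None, None))
```

Keeping the chunk-producing logic in one place means `predict_stream` can be used as-is in a notebook today, and `predict` can be simplified once the endpoint supports streaming.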