
How to Consume Registered Connections in Function Calling with PromptFlow in AI Studio?


I'm developing a PromptFlow flow in AI Studio that uses function calling, and I'm struggling to consume registered connections inside the functions that the LLM calls.

Context

In ordinary Python nodes (without function calling), I declare connections as typed parameters of the function decorated with @tool, and PromptFlow injects the registered connections directly into that function at run time.

Problem: when I switch to function calling from an LLM node, this approach doesn't work. The LLM node that calls the function knows nothing about the connection arguments; it only produces the arguments described in the function definition it was given.
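For context, here is a sketch of the kind of function definition an LLM node would be given for run_db_lookup, in the OpenAI function-calling format (name, description, JSON Schema parameters), which is typically passed as a list to the LLM node's functions input. The description strings are hypothetical placeholders. Because only message appears in the schema, the arguments the model generates can never contain the connections.

```python
import json

# Hypothetical function definition for run_db_lookup (OpenAI function-calling format).
# Only the arguments the LLM should generate are declared; the connection
# parameters (db_conn, aoai_conn) are deliberately absent from the schema.
RUN_DB_LOOKUP_FUNCTION = {
    "name": "run_db_lookup",
    "description": "Look up relevant records in the database for a user message.",  # placeholder
    "parameters": {
        "type": "object",
        "properties": {
            "message": {
                "type": "string",
                "description": "The user's question.",  # placeholder
            },
        },
        "required": ["message"],
    },
}

# The list of definitions handed to the LLM node.
FUNCTIONS = [RUN_DB_LOOKUP_FUNCTION]

if __name__ == "__main__":
    print(json.dumps(FUNCTIONS, indent=2))
```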

Specific Scenario

I'm building a flow where an LLM node calls a function to retrieve data from a database. The function prototype runs successfully, but I run into issues when trying to use connections within these functions. Here is a simplified version of my code:

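import json

from openai import AzureOpenAI
from promptflow import tool
from promptflow.connections import AzureOpenAIConnection, CustomConnection


def run_db_lookup(message: str, db_conn: CustomConnection, aoai_conn: AzureOpenAIConnection) -> str:

    db_client = ...  # database client; requires credentials from db_conn (omitted)

    # Client used to embed the initial question
    client = AzureOpenAI(
        api_key=aoai_conn.api_key,
        api_version=aoai_conn.api_version,
        azure_endpoint=aoai_conn.api_base,
    )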

    # Function implementation

    return output


@tool
def run_function(response_message: dict) -> str:
    function_call = response_message.get("function_call", None)
    if function_call and "name" in function_call and "arguments" in function_call:
        function_name = function_call["name"]
        # The LLM returns the arguments as a JSON-encoded string
        function_args = json.loads(function_call["arguments"])
        print(function_args)
        # Only the LLM-generated arguments (here just "message") are passed,
        # so db_conn and aoai_conn are never supplied
        result = globals()[function_name](**function_args)
    else:
        print("No function call")
        if isinstance(response_message, dict):
            result = response_message.get("content", "")
        else:
            result = response_message
    return result

When I run this with function calling, I get (unsurprisingly) the following error: run_db_lookup() missing 2 required positional arguments: 'db_conn' and 'aoai_conn'. I'm not sure how to pass these fixed arguments (the connections) to the function.
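The failure can be reproduced without PromptFlow or Azure. The sketch below uses a stub run_db_lookup (a hypothetical stand-in that only echoes its input) and a hand-written response_message in the legacy OpenAI function_call shape, where arguments arrives as a JSON-encoded string; the sample question text is made up. Unpacking only the LLM-generated arguments leaves both connection parameters unfilled, which raises the same error.

```python
import json


def run_db_lookup(message: str, db_conn, aoai_conn) -> str:
    # Hypothetical stub: the real function would embed the message and query the DB.
    return f"lookup({message!r})"


# Illustrative LLM output: "arguments" is a JSON string, not a dict.
response_message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "run_db_lookup",
        "arguments": json.dumps({"message": "top customers in 2023"}),
    },
}


def dispatch_llm_args_only(msg: dict) -> str:
    """Mimics the original run_function: only the LLM's arguments are passed."""
    call = msg["function_call"]
    args = json.loads(call["arguments"])
    return globals()[call["name"]](**args)


def reproduce() -> str:
    # Return the TypeError message instead of letting it propagate.
    try:
        dispatch_llm_args_only(response_message)
    except TypeError as exc:
        return str(exc)
    return "no error"


if __name__ == "__main__":
    print(reproduce())
```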

Question

What are the best practices for consuming registered connections in functions when using function calling within PromptFlow? How can I properly inject or pass these connections to my function so that the LLM node recognizes them?

Additional Information

Tools Used: AI Studio, PromptFlow


Solution

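  • I got it working this way, although I'm not sure it follows best practices. The connections are declared as parameters of the @tool function run_function, so PromptFlow injects them there, and run_function forwards them to run_db_lookup together with the arguments generated by the LLM.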

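    import json

    from openai import AzureOpenAI
    from promptflow import tool
    from promptflow.connections import AzureOpenAIConnection, CustomConnection


    def run_db_lookup(message: str, db_conn: CustomConnection, aoai_conn: AzureOpenAIConnection) -> str:

        db_client = ...  # database client built from db_conn (omitted)

        # Embed the initial question
        client = AzureOpenAI(
            api_key=aoai_conn.api_key,
            api_version=aoai_conn.api_version,
            azure_endpoint=aoai_conn.api_base,
        )

        response = client.embeddings.create(
            input=message,
            model="text-embedding-ada-002",  # on Azure, this is the deployment name
        )
        input_vec = response.data[0].embedding

        # Get few-shot examples from the vector database
        # (db_similarity_search is the author's own helper and is not shown)
        response = db_similarity_search(input_vec)

        return response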
    
    
    @tool
    def run_function(response_message: dict, db_conn: CustomConnection, aoai_conn: AzureOpenAIConnection) -> str:
        function_call = response_message.get("function_call", None)
        if function_call and "name" in function_call and "arguments" in function_call:
            function_name = function_call["name"]
            function_args = json.loads(function_call["arguments"])
    
            # Pass db_conn and aoai_conn to run_db_lookup
            if function_name == "run_db_lookup":
                result = run_db_lookup(**function_args, db_conn=db_conn, aoai_conn=aoai_conn)
            else:
                result = globals()[function_name](**function_args)
        else:
            print("No function call")
            if isinstance(response_message, dict):
                result = response_message.get("content", "")
            else:
                result = response_message
        return result
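One possible generalization of this approach, sketched with the standard library only: instead of special-casing run_db_lookup and falling back to globals() (which would let the model invoke any module-level function by name), build an explicit registry inside the @tool function and pre-bind each function's connections with functools.partial. FakeConnection, search_docs and build_registry are hypothetical stand-ins; in a real flow the connections would be the CustomConnection and AzureOpenAIConnection objects that PromptFlow injects into run_function.

```python
import json
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict


@dataclass
class FakeConnection:
    # Hypothetical stand-in for a PromptFlow connection object.
    name: str


def run_db_lookup(message: str, db_conn: FakeConnection, aoai_conn: FakeConnection) -> str:
    # Stub: the real version embeds the message and queries the vector DB.
    return f"{message} via {db_conn.name}/{aoai_conn.name}"


def search_docs(query: str, aoai_conn: FakeConnection) -> str:
    # Hypothetical second function that needs only one connection.
    return f"docs for {query} via {aoai_conn.name}"


def build_registry(db_conn, aoai_conn) -> Dict[str, Callable[..., Any]]:
    # Only functions listed here can be called by the LLM; each has its
    # connections pre-bound, so the LLM supplies just its own arguments.
    return {
        "run_db_lookup": partial(run_db_lookup, db_conn=db_conn, aoai_conn=aoai_conn),
        "search_docs": partial(search_docs, aoai_conn=aoai_conn),
    }


def run_function(response_message: dict, db_conn, aoai_conn) -> str:
    # In a flow this would carry the @tool decorator, so PromptFlow injects
    # db_conn and aoai_conn from the node's inputs.
    function_call = response_message.get("function_call")
    if not function_call:
        return response_message.get("content") or ""
    func = build_registry(db_conn, aoai_conn).get(function_call["name"])
    if func is None:
        raise ValueError(f"Unknown function: {function_call['name']}")
    args = json.loads(function_call["arguments"])
    # Call-time keywords would override pre-bound ones, so drop any
    # LLM-supplied argument whose name collides with a bound connection.
    args = {k: v for k, v in args.items() if k not in func.keywords}
    return func(**args)


if __name__ == "__main__":
    msg = {"function_call": {"name": "run_db_lookup", "arguments": json.dumps({"message": "hi"})}}
    print(run_function(msg, FakeConnection("db"), FakeConnection("aoai")))
```

Adding a new callable function then means one registry entry plus its schema in the LLM node, with no change to the dispatch logic.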