langchain, large-language-model, chain

How to get intermediary chain step outputs in final output?


For simplicity's sake, I have the following chain:

  1. extract names from a list
  2. validate these names against a second list of names

What I want is to receive a JSON with all intermediary steps at the end as well as the input:

{"first_value": "Dave, John, carrot", "first_prompt_output" "Dave, John", "possible_values": "John"...}

But I am confused by the LangChain docs. With RunnablePassthrough I can get all of the inputs through to the end, but in a hard-to-read, nested format. I fiddled with it for half an hour (tried RunnableParallel, RunnablePassthrough.assign(), ...) but I can't seem to get it right; there must be some key feature I'm missing. This is the output I currently get:

{'result': {'third_prompt_output': 'Johny'},
 'first_prompt_output': {'first_prompt_output': {'first_prompt_output': ['Dave',
    'John']},
  'possible_values': {'first_value': 'Dave, John',
   'possible_values': ['John']}}}
And here is my code:

from operator import itemgetter

# Import paths may differ slightly between LangChain versions
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

first_prompt = PromptTemplate.from_template(
    """Find all names in the following text and extract them as JSON with a field `first_prompt_output`: {first_value}

first_prompt_output:"""
)

second_prompt = PromptTemplate.from_template(
    """Here is a list of possible values: {possible_values} and a list of found values: {first_prompt_output}. Find the values that are in both lists. Return a JSON with the fields `first_prompt_output`, `second_prompt_output` and `possible_values`."""
)

first_value = "Dave, John"
possible_values = ["John"]

first_chain = (
    first_prompt
    | llm
    | SimpleJsonOutputParser()
)

second_chain = (
    second_prompt
    | llm
    | SimpleJsonOutputParser()
)

chain = (
    {"first_prompt_output": first_chain, "possible_values": RunnablePassthrough(), "first_value": RunnablePassthrough()} 
    | RunnableParallel(result={"second_prompt_output": second_chain, "first_value": itemgetter("first_value")})
)

chain.invoke({"first_value": first_value, "possible_values": possible_values})

I tried using RunnableParallel and RunnablePassthrough.assign, but neither does what I expect it to. What I basically need is a dict.update(), but inside the pipeline; conceptually, something like the plain-Python sketch below.
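
In plain Python terms, this is all I want the pipeline to do with its state (the values here are just illustrative, not real chain output):

state = {"first_value": "Dave, John, carrot", "possible_values": ["John"]}
state.update({"first_prompt_output": ["Dave", "John"]})   # merge in the first step's output
state.update({"second_prompt_output": ["John"]})          # merge in the second step's output
# state now holds the original input plus every intermediate output in one flat dict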


Solution

  • I have not found a built-in way to do this with LangChain, but I wrote a function that flattens the output into the shape I want. It feels a bit clunky and I believe there must be a better solution.

    The key is to add the following function to the chain:

    def flatten_dict(*vars) -> dict:
        '''
        Flatten dictionaries by merging the values of nested dict entries
        into the top level, removing unnecessary mid-level keys.
        When piped into a chain, LangChain wraps this plain function in a
        RunnableLambda, so it can be composed with the | operator.
        '''
        flat = {}
        for var in vars:
            for key, value in var.items():
                if isinstance(value, dict):
                    # Pull the nested dict's entries up to the top level
                    flat.update(value)
                else:
                    flat[key] = value
        return flat
    
    chain = (
        {"first_prompt_output": first_chain, "possible_values": RunnablePassthrough(), "first_value": RunnablePassthrough()}
        | RunnableParallel(result={"second_prompt_output": second_chain, "first_value": itemgetter("first_value")})
        | flatten_dict
    )
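
    For illustration, this is roughly what flatten_dict does to a dict shaped like the chain's intermediate output (the values here are made up, not real LLM output):

    nested = {
        "result": {"second_prompt_output": ["John"]},
        "first_value": "Dave, John",
    }
    flatten_dict(nested)
    # -> {"second_prompt_output": ["John"], "first_value": "Dave, John"}

    With flatten_dict piped in at the end, chain.invoke({"first_value": first_value, "possible_values": possible_values}) should now return a single flat dict containing the input keys alongside each step's output.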