For simplicity's sake, I have the following chain (shown in the code below).
What I want is to receive, at the end, a JSON object with all intermediate steps as well as the input:
{"first_value": "Dave, John, carrot", "first_prompt_output": "Dave, John", "possible_values": "John"...}
But I am confused by the LangChain docs. I seem to be able to get all of the inputs using RunnablePassthrough, but only in a hard-to-read format. I fiddled with it for half an hour (tried RunnableParallel, RunnablePassthrough.assign(), ...) but couldn't get it to work, so there must be some key feature I'm missing. The output currently looks like this:
{'result': {'third_prompt_output': 'Johny'},
 'first_prompt_output': {'first_prompt_output': {'first_prompt_output': ['Dave', 'John']},
                         'possible_values': {'first_value': 'Dave, John',
                                             'possible_values': ['John']}}}
# Import paths assume a recent langchain-core; older releases expose these under `langchain`.
from operator import itemgetter
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# `llm` is assumed to be an already configured LangChain model (not shown here).

first_prompt = PromptTemplate.from_template("""Find all names in the following text and extract them as JSON with a field `first_prompt_output`: {first_value}
first_prompt_output:""")
second_prompt = PromptTemplate.from_template("""Here is a list of possible values: {possible_values} and a list of found values {first_prompt_output}. Find values that are in both lists. Return a JSON with the fields `first_prompt_output` and `second_prompt_output` and `possible_values`.""")
first_value = "Dave, John"
possible_values = ["John"]
first_chain = (
    first_prompt
    | llm
    | SimpleJsonOutputParser()
)
second_chain = (
    second_prompt
    | llm
    | SimpleJsonOutputParser()
)
chain = (
    # Note: RunnablePassthrough() forwards the whole input dict, not a single key.
    {"first_prompt_output": first_chain, "possible_values": RunnablePassthrough(), "first_value": RunnablePassthrough()}
    | RunnableParallel(result={"second_prompt_output": second_chain, "first_value": itemgetter("first_value")})
)
chain.invoke({"first_value": first_value, "possible_values": possible_values})
I tried using RunnableParallel and RunnablePassthrough.assign, but neither does what I expect. What I basically need is a dict.update(), but inside the pipeline.
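To make the desired behaviour concrete, here is a plain-Python sketch of a "dict.update() in the pipeline" step. It does not use LangChain at all: `merge_step` is a hypothetical helper, and the two `fake_*` functions are stand-ins I made up in place of the LLM chains, so no model is called.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def merge_step(key: str, step: Callable[[State], Any]) -> Callable[[State], State]:
    """Wrap `step` so its result is added to a copy of the input dict under `key`."""
    def run(state: State) -> State:
        # Same effect as dict.update(), but without mutating the input.
        return {**state, key: step(state)}
    return run


# Hypothetical stand-ins for first_chain and second_chain (no LLM involved).
def fake_first_chain(state: State) -> list:
    return [name.strip() for name in state["first_value"].split(",")]


def fake_second_chain(state: State) -> list:
    return [v for v in state["first_prompt_output"] if v in state["possible_values"]]


steps = [
    merge_step("first_prompt_output", fake_first_chain),
    merge_step("second_prompt_output", fake_second_chain),
]

state: State = {"first_value": "Dave, John", "possible_values": ["John"]}
for step in steps:
    state = step(state)  # each step sees everything accumulated so far

print(state)
```

Every step receives the flat dict built so far and returns it with one more key, so the final result contains the inputs and all intermediate outputs at the top level.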
I have not found a built-in way to do this with LangChain, but I found a function that flattens the output and gives me what I want. It seems a bit clunky, and I believe there must be a better solution.
The key is to add the following function to the chain:
def flatten_dict(*dicts) -> dict:
    '''
    Flatten dictionaries by removing unnecessary mid-level keys.
    Values that are themselves dicts are merged into the top level (one level deep only).
    This is a plain function that returns a dict; piped into a chain, it acts as a chain step.
    '''
    flat = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict):
                flat.update(value)
            else:
                flat[key] = value
    return flat
chain = (
    {"first_prompt_output": first_chain, "possible_values": RunnablePassthrough(), "first_value": RunnablePassthrough()}
    | RunnableParallel(result={"second_prompt_output": second_chain, "first_value": itemgetter("first_value")})
    | flatten_dict
)
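Note that flatten_dict only lifts one level of nesting. If the chain output nests deeper than that, a recursive variant might be needed. Here is a minimal sketch: `flatten_nested` and the sample dict are made up for illustration and are not part of LangChain or of the chain above.

```python
from typing import Any, Dict


def flatten_nested(nested: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively lift the contents of dict values to the top level.

    If a key occurs more than once, the occurrence visited last wins.
    """
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        if isinstance(value, dict):
            # Drop the wrapping key and merge whatever is inside it.
            flat.update(flatten_nested(value))
        else:
            flat[key] = value
    return flat


# Hypothetical nested output, two levels deep in one branch.
nested = {
    "result": {"second_prompt_output": ["John"]},
    "inputs": {"first_value": "Dave, John", "extra": {"possible_values": ["John"]}},
}
flat = flatten_nested(nested)
print(flat)
```

Because wrapping keys are discarded, this only makes sense when the leaf keys are unique, or when it is acceptable for later duplicates to overwrite earlier ones.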