python · nlp · langchain · large-language-model

Langchain: Custom Output Parser not working with ConversationChain


I am creating a chatbot with LangChain's ConversationChain, so it needs conversation memory. However, at the end of each of its responses it starts a new line and writes a bunch of gibberish, so I wrote a custom output parser to strip that text. The parser gives a validation error instead. I am new to LangChain, so any help would be appreciated.

from langchain import PromptTemplate
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory

class MyOutputParser:
    def __init__(self):
        pass

    def parse(self, output):
        cut_off = output.find("\n", 3)
        # delete everything after new line
        return output[:cut_off]

template = """You will answer the following questions the best you can, being as informative and factual as possible.
If you don't know, say you don't know. 

Current conversation:
{history}
Human: {input}
AI Assistant:"""

the_output_parser = MyOutputParser()
print(type(the_output_parser))

PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
conversation = ConversationChain(
    prompt=PROMPT,
    llm=local_llm,
    memory=ConversationBufferWindowMemory(k=4),
    return_final_only=True,
    verbose=False,
    output_parser=the_output_parser,
)

This is the error it gives me:

ValidationError: 1 validation error for ConversationChain
output_parser
  value is not a valid dict (type=type_error.dict)
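The error itself comes from pydantic validation: ConversationChain is a pydantic model, and its output_parser field is declared with an output-parser base type, so assigning an arbitrary object is rejected at construction time. A minimal sketch of the same failure, with hypothetical Parser/Chain classes standing in for LangChain's (assumes pydantic is installed):

```python
from pydantic import BaseModel, ValidationError

class Parser(BaseModel):      # stands in for BaseLLMOutputParser
    pass

class Chain(BaseModel):       # stands in for ConversationChain
    output_parser: Parser

class MyOutputParser:         # plain class, no inheritance
    def parse(self, output):
        return output

try:
    Chain(output_parser=MyOutputParser())
except ValidationError as e:
    # on pydantic v1 (used by LangChain 0.0.x) the message is
    # "value is not a valid dict (type=type_error.dict)"
    print(e)
```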

Solution

  • I'm not sure exactly what you're trying to do, and this area depends heavily on which version of LangChain you're using, but your output parser neither inherits from BaseLLMOutputParser nor follows its method signatures, as it should.

    For LangChain 0.0.261, to fix your specific question about the output parser, try:

    from langchain.chains.conversation.memory import ConversationBufferWindowMemory
    from langchain import PromptTemplate
    from langchain.chains import ConversationChain
    from langchain.schema.output_parser import BaseLLMOutputParser
    
    class MyOutputParser(BaseLLMOutputParser):
        def __init__(self):
            super().__init__()

        def parse_result(self, result):
            # in LangChain 0.0.x, parse_result receives a list of
            # Generation objects rather than a plain string
            text = result[0].text
            cut_off = text.find("\n", 3)
            # delete everything after the first newline
            return text[:cut_off]
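One caveat with the truncation logic: str.find returns -1 when no newline is found, in which case the slice text[:cut_off] silently drops the last character. A guarded sketch (truncate_after_newline is a hypothetical helper name, not part of LangChain):

```python
def truncate_after_newline(text: str) -> str:
    # find the first newline at index >= 3; str.find returns -1 if absent
    cut_off = text.find("\n", 3)
    # only truncate when a newline actually exists
    return text if cut_off == -1 else text[:cut_off]

print(truncate_after_newline("A clean answer.\n###gibberish###"))  # → A clean answer.
print(truncate_after_newline("no newline here"))                   # → no newline here
```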