
Dynamic filtering with ConversationalRetrievalChain


I'm trying to filter the underlying vector store (Chroma) by metadata.

db = Chroma.from_documents(texts, embeddings)

It works like this:

qa = ConversationalRetrievalChain.from_llm(
    OpenAI(openai_api_key=get_random_key(OPENAI_API_KEY_POOL), cache=True, temperature=0),
    VectorStoreRetriever(vectorstore=db, search_kwargs={"filter": {"source": "data/my.pdf"}}),
    verbose=True,
    return_source_documents=True) 
result = qa({"question": query, "chat_history": []})

But this would mean creating a separate chain for each document, which seems odd. However, when I try to pass the filter to an existing chain instead, it has no effect: the chain returns results from all the documents in the db.

qa = ConversationalRetrievalChain.from_llm(
    OpenAI(openai_api_key=get_random_key(OPENAI_API_KEY_POOL), cache=True, temperature=0),
    VectorStoreRetriever(vectorstore=db),
    verbose=True,
    return_source_documents=True)

filter = {'source': 'data/my.pdf'}

result = qa({"question": query, "chat_history": [], "filter": filter})

Am I missing something, or does this really not work without extending the existing classes or modifying LangChain's source code?
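Both snippets filter on the `source` metadata field. Chroma treats a filter like `{"source": "data/my.pdf"}` as an exact equality match against the metadata stored with each chunk (with PDF loaders such as `PyPDFLoader`, `source` is usually the file path exactly as passed to the loader). The pure-Python sketch below only models that matching rule; `matches`, `filter_chunks` and the sample chunks are hypothetical and are not part of LangChain or Chroma. It shows why the filter value has to be the exact stored path: `"my.pdf"` does not match `"data/my.pdf"`.

```python
# Toy model of metadata filtering (hypothetical helpers, not Chroma's API):
# each stored chunk carries the metadata of the document it came from, and an
# equality filter keeps only chunks whose metadata values match exactly.
from typing import Dict, List


def matches(metadata: Dict[str, str], where: Dict[str, str]) -> bool:
    """True if every key/value pair in `where` equals the chunk's metadata."""
    return all(metadata.get(key) == value for key, value in where.items())


def filter_chunks(chunks: List[dict], where: Dict[str, str]) -> List[dict]:
    """Return the chunks whose metadata satisfies the equality filter."""
    return [c for c in chunks if matches(c["metadata"], where)]


# Hypothetical chunks, roughly as a PDF loader plus text splitter might produce.
chunks = [
    {"text": "intro", "metadata": {"source": "data/my.pdf"}},
    {"text": "details", "metadata": {"source": "data/my.pdf"}},
    {"text": "other", "metadata": {"source": "data/other.pdf"}},
]

print(len(filter_chunks(chunks, {"source": "data/my.pdf"})))  # 2
print(len(filter_chunks(chunks, {"source": "my.pdf"})))       # 0: no basename match
```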


Solution

  • I ended up extending both classes I was using so that the filter gets passed through:

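    from typing import Any, Dict, List, Optional

    from langchain.callbacks.manager import CallbackManagerForChainRun
    from langchain.chains import ConversationalRetrievalChain
    from langchain.schema import Document
    from langchain.vectorstores.base import VectorStoreRetriever


    class ConversationalRetrievalChainPassArgs(ConversationalRetrievalChain):

        def _get_docs(self, question: str, inputs: Dict[str, Any], *,
                      run_manager: CallbackManagerForChainRun) -> List[Document]:
            """Get docs, forwarding the optional "filter" input to the retriever."""
            # `inputs` is the dict passed to qa({...}); "filter" is an extra key.
            docs = self.retriever._get_relevant_documents(
                question, inputs.get("filter")
            )
            return self._reduce_tokens_below_limit(docs)


    class VectorStoreRetrieverWithFiltering(VectorStoreRetriever):

        def _get_relevant_documents(self, query: str,
                                    filter: Optional[dict] = None) -> List[Document]:
            # Merge the per-call filter into a copy of search_kwargs so "filter"
            # is not passed twice if search_kwargs already contains one.
            search_kwargs = dict(self.search_kwargs)
            if filter is not None:
                search_kwargs["filter"] = filter
            if self.search_type == "similarity":
                docs = self.vectorstore.similarity_search(query, **search_kwargs)
            ...  # other search_type branches elided
            return docs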
    

    If anyone has a more elegant solution, please share.
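The mechanism behind this workaround can be modelled without LangChain. As far as I can tell from the LangChain source of that period, `ConversationalRetrievalChain._get_docs` hands only the (possibly rephrased) question to the retriever, so an extra `"filter"` key in the inputs is accepted but never read. The sketch below mirrors that structure with hypothetical stand-in classes (`ToyRetriever`, `ToyChain`, `ToyChainPassArgs`); it does not use or reproduce LangChain's real API, and a substring check stands in for vector similarity.

```python
# Toy model of the subclassing fix (hypothetical classes, not LangChain's):
# the base chain only passes the question to the retriever, so an extra
# "filter" key is ignored; the subclass forwards inputs.get("filter").
from typing import Any, Dict, List, Optional

DOCS = [
    {"text": "refund policy", "metadata": {"source": "data/my.pdf"}},
    {"text": "refund form", "metadata": {"source": "data/other.pdf"}},
]


class ToyRetriever:
    def __init__(self, search_kwargs: Optional[Dict[str, Any]] = None):
        self.search_kwargs = search_kwargs or {}

    def search(self, query: str, filter: Optional[dict] = None) -> List[dict]:
        # Merge the per-call filter into a copy of search_kwargs rather than
        # passing filter=... alongside **search_kwargs, which would raise
        # TypeError if search_kwargs already contained "filter".
        kwargs = dict(self.search_kwargs)
        if filter is not None:
            kwargs["filter"] = filter
        where = kwargs.get("filter") or {}
        return [
            d for d in DOCS
            if query in d["text"]  # stand-in for similarity search
            and all(d["metadata"].get(k) == v for k, v in where.items())
        ]


class ToyChain:
    def __init__(self, retriever: ToyRetriever):
        self.retriever = retriever

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[dict]:
        return self.retriever.search(question)  # extra input keys never read

    def __call__(self, inputs: Dict[str, Any]) -> List[dict]:
        return self._get_docs(inputs["question"], inputs)


class ToyChainPassArgs(ToyChain):
    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[dict]:
        return self.retriever.search(question, inputs.get("filter"))


inputs = {"question": "refund", "chat_history": [],
          "filter": {"source": "data/my.pdf"}}
print(len(ToyChain(ToyRetriever())(inputs)))          # 2: filter ignored
print(len(ToyChainPassArgs(ToyRetriever())(inputs)))  # 1: filter applied
```

Merging into a copy of `search_kwargs` also means a per-call filter overrides any default filter set when the retriever was built, and calls without a `"filter"` key fall back to that default.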