I'm trying to add metadata filtering of the underlying vector store (Chroma).
db = Chroma.from_documents(texts, embeddings)
It works when the filter is fixed at construction time, like this:
qa = ConversationalRetrievalChain.from_llm(
    OpenAI(openai_api_key=get_random_key(OPENAI_API_KEY_POOL), cache=True, temperature=0),
    VectorStoreRetriever(vectorstore=db, search_kwargs={"filter": {"source": "data/my.pdf"}}),
    verbose=True,
    return_source_documents=True,
)
result = qa({"question": query, "chat_history": []})
But this would imply creating a separate chain for each document, which seems wrong. However, when I instead try to pass the filter to an existing chain at query time, it has no effect: results come back from all the documents in the db.
qa = ConversationalRetrievalChain.from_llm(
    OpenAI(openai_api_key=get_random_key(OPENAI_API_KEY_POOL), cache=True, temperature=0),
    VectorStoreRetriever(vectorstore=db),
    verbose=True,
    return_source_documents=True,
)
filter = {'source': 'my.pdf'}
result = qa({"question": query, "chat_history": [], "filter": filter})
Am I missing something? Or will this really not work without extending the existing classes / modifying LangChain's source code?
I ended up extending both classes I used to pass the filter:
from typing import Any, Dict, List

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever


class ConversationalRetrievalChainPassArgs(ConversationalRetrievalChain):

    def _get_docs(self, question: str, inputs: Dict[str, Any], *,
                  run_manager: CallbackManagerForChainRun) -> List[Document]:
        """Get docs, forwarding the per-call filter to the retriever."""
        docs = self.retriever._get_relevant_documents(question, inputs["filter"])
        return self._reduce_tokens_below_limit(docs)
class VectorStoreRetrieverWithFiltering(VectorStoreRetriever):

    def _get_relevant_documents(self, query: str, filter: dict) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, filter=filter, **self.search_kwargs)
        ...
        return docs
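To make the shape of the workaround concrete, here is a dependency-free sketch of the same pattern: the chain pulls a per-call "filter" out of its inputs dict and forwards it to the retriever, instead of the filter being frozen into search_kwargs at construction time. The class and method names mirror the LangChain ones above, but everything here is hypothetical stand-in code, not the real library.

```python
class FakeVectorStore:
    """Stands in for Chroma: documents tagged with metadata."""
    def __init__(self, docs):
        self.docs = docs  # list of (text, metadata) pairs

    def similarity_search(self, query, filter=None):
        # Real Chroma ranks by embedding similarity; here we only
        # apply the metadata filter, which is the part at issue.
        hits = self.docs
        if filter:
            hits = [d for d in hits
                    if all(d[1].get(k) == v for k, v in filter.items())]
        return hits


class FilteringRetriever:
    """Stand-in for VectorStoreRetrieverWithFiltering."""
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore

    def get_relevant_documents(self, query, filter=None):
        # The filter arrives per call, not via constructor kwargs.
        return self.vectorstore.similarity_search(query, filter=filter)


class FilterPassingChain:
    """Stand-in for ConversationalRetrievalChainPassArgs."""
    def __init__(self, retriever):
        self.retriever = retriever

    def __call__(self, inputs):
        # Forward inputs["filter"] down to the retriever — the same
        # move the subclassed _get_docs makes above.
        return self.retriever.get_relevant_documents(
            inputs["question"], filter=inputs.get("filter"))


store = FakeVectorStore([
    ("page 1", {"source": "data/my.pdf"}),
    ("intro",  {"source": "data/other.pdf"}),
])
chain = FilterPassingChain(FilteringRetriever(store))
result = chain({"question": "q", "filter": {"source": "data/my.pdf"}})
# result contains only the documents whose metadata matches the filter
```

The same chain object now serves any document: only the "filter" entry in the inputs dict changes between calls, which is exactly what the stock classes don't allow.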
If anyone has a more elegant solution, please share.