Search code examples
python python-asyncio langchain large-language-model

How to use langchain RetrievalQA with asyncio?


I want to parallelize RetrievalQA with asyncio but I am unable to figure out how.

This is how my code works serially:

import langchain
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.schema.vectorstore import VectorStoreRetriever
import asyncio
import nest_asyncio

# Retriever the chain uses to fetch context documents from the FAISS store.
retriever = VectorStoreRetriever(vectorstore=FAISS(...))

chat = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0.7)

# return_source_documents=True: each result also carries the retrieved
# documents under "source_documents" (read in the loop below).
qa_chain = RetrievalQA.from_llm(
    chat,
    retriever=retriever,
    return_source_documents=True,
)

queries = ['query1', 'query2', 'query3']
data_to_append = []

# Serial baseline: every query is answered before the next one starts.
for query in queries:
    vectordbkwargs = {"search_distance": 0.9}
    chain_inputs = {"query": query, "vectordbkwargs": vectordbkwargs}
    result = qa_chain(chain_inputs)

    record = {
        "Query": query,
        "Source_Documents": result["source_documents"],
        "Generated_Text": result["result"],
    }
    data_to_append.append(record)

Here is my attempt to parallelize it with asyncio, but RetrievalQA doesn't seem to work asynchronously:

import langchain
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.schema.vectorstore import VectorStoreRetriever
import asyncio
import nest_asyncio

# Retriever the chain uses to fetch context documents from the FAISS store.
retriever = VectorStoreRetriever(vectorstore=FAISS(...))

chat = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0.7)

# return_source_documents=True: each result also carries the retrieved
# documents under "source_documents".
qa_chain = RetrievalQA.from_llm(
    chat,
    retriever=retriever,
    return_source_documents=True,
)

# Inputs to fan out, and the shared list the coroutines append their rows to.
queries = ['query1', 'query2', 'query3']
data_to_append = []



async def process_query(query):
        """Run one query through qa_chain and append the outcome to data_to_append.

        The appended row holds the query, the chain's "source_documents" and
        its "result" text. Nothing is returned.
        """

        vectordbkwargs = {"search_distance": 0.9}
        # NOTE(review): this is the line the question is about. qa_chain(...)
        # is the same plain call used in the serial version, where its return
        # value is indexed like a dict; it presumably does not return an
        # awaitable, so this `await` is expected to fail.
        result = await qa_chain({"query": query, "vectordbkwargs": vectordbkwargs})
        data_to_append.append({"Query": query, "Source_Documents": result["source_documents"], "Generated_Text": result["result"]})


async def main():
    """Launch process_query for every entry in queries and wait for all of them."""
    # One coroutine per query; gather runs them concurrently and only
    # returns once every one has finished.
    pending = [process_query(query) for query in queries]
    await asyncio.gather(*pending)

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # NOTE(review): nest_asyncio.apply() presumably patches asyncio so that
    # asyncio.run() also works when an event loop is already running (e.g. in
    # a notebook) -- confirm whether it is needed for a plain script.
    nest_asyncio.apply()
    # Start an event loop and drive main() to completion.
    asyncio.run(main())

Any help would be greatly appreciated.


Solution

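  • To make it work async, await one of the chain's coroutine methods instead of calling the chain object directly (calling `qa_chain(...)` runs synchronously and returns a plain dict, which cannot be awaited). The solution I found was to use RetrievalQA._acall; since that is a private method, the public `acall` wrapper around it is generally the safer choice.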

    Here is a sample code snippet which has only one minor change:

    import langchain
    from langchain.chat_models import ChatOpenAI
    from langchain import PromptTemplate, LLMChain
    from langchain.chains import RetrievalQA
    from langchain.vectorstores import FAISS
    from langchain.schema.vectorstore import VectorStoreRetriever
    import asyncio
    import nest_asyncio
    
    # Retriever the chain uses to fetch context documents from the FAISS store.
    retriever = VectorStoreRetriever(vectorstore=FAISS(...))

    chat = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0.7)

    # return_source_documents=True: each result also carries the retrieved
    # documents under "source_documents".
    qa_chain = RetrievalQA.from_llm(
        chat,
        retriever=retriever,
        return_source_documents=True,
    )

    # Inputs to fan out, and the shared list the coroutines append their rows to.
    queries = ['query1', 'query2', 'query3']
    data_to_append = []
    
    
    
    async def process_query(query):
        """Answer one query asynchronously and record the outcome.

        Awaits the chain's coroutine entry point, then appends a row with the
        query, the chain's "source_documents" and its "result" text to the
        module-level ``data_to_append`` list. Nothing is returned.

        Rows are appended as each coroutine finishes, so when several queries
        run concurrently the order of ``data_to_append`` may differ from the
        order of ``queries``.
        """
        vectordbkwargs = {"search_distance": 0.9}
        # Await the public ``acall`` coroutine instead of calling the chain
        # object directly (which is synchronous). ``acall`` is preferred over
        # the private ``_acall``: it is the supported API and wraps ``_acall``
        # with the chain's usual input/output handling and callbacks.
        result = await qa_chain.acall({"query": query, "vectordbkwargs": vectordbkwargs})
        data_to_append.append(
            {
                "Query": query,
                "Source_Documents": result["source_documents"],
                "Generated_Text": result["result"],
            }
        )
    
    
    async def main():
        """Launch process_query for every entry in queries and wait for all of them."""
        # One coroutine per query; gather runs them concurrently and only
        # returns once every one has finished.
        pending = [process_query(query) for query in queries]
        await asyncio.gather(*pending)
    
    # Script entry point: only runs when the file is executed directly.
    if __name__ == "__main__":
        # NOTE(review): nest_asyncio.apply() presumably patches asyncio so that
        # asyncio.run() also works when an event loop is already running (e.g.
        # in a notebook) -- confirm whether it is needed for a plain script.
        nest_asyncio.apply()
        # Start an event loop and drive main() to completion.
        asyncio.run(main())