Search code examples
python-3.xlangchainragchainlitretrievalqa

Issue on memory ram with langchain RetrievalQA


Hi im doing a RAG system with multiple vector databases using chainlit, langchain and FAISS. Few days ago i saw that RAG was using a lot of memory ram like 10gb, so i want to fix it, but i don't know if the langchain has a close method or something like that then i can use to close process of retrievalQA.

I going to share my code if someone know how can i fix this issue pls tell me.

from typing import Dict, Optional
from langchain import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import chainlit as cl
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os

from google.oauth2 import id_token
from google.auth.transport import requests
from utils.Retrival import Retrival; 



load_dotenv()
temp = 0
groq_api = os.getenv("GROQ_KEY")



db_faiss = ""
GROQ_MODEL = "llama-3.1-70b-versatile"

#DB_FAISS_PATH = "./vectorestore_sinHU"
custom_prompt_template ="""


Context: {context}
Question: {question}

Recuerda no puedes creear una respuesta, solamnete te pudes basar en el contexto.
la mejor respuesta:
"""




def set_custom_prompt():
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context','question'])
    return prompt

def retrieval_qa_chain(llm, prompt, db, kw = 80):
    qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                       chain_type='stuff',
                                       retriever=db.as_retriever(search_kwargs={'k': kw}),
                                       return_source_documents=True,
                                       chain_type_kwargs={'prompt': prompt}
                                       )
    return qa_chain

def qa_bot(db,kw):
    embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1",
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(db, embeddings,allow_dangerous_deserialization=True)
    llm = ChatGroq(groq_api_key = groq_api,temperature=temp, model_name=GROQ_MODEL)
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db,kw)

    return qa  


retrivalAux = Retrival() 

@cl.on_chat_start
async def start():
    global retrivalAux
    
    chat_profile = cl.user_session.get('chat_profile')

    msg = cl.Message(content="Cargando chatbot por favor espere...")
    await msg.send()

    match chat_profile:
        case "chat1":           
            msg.content = "Hola."
            db_faiss = ""
            if retrivalAux.getNombre() != "chat1" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat1")                retrivalAux.setRetrival(retrivalQA=qa_bot(db="./vectorstores/chat1",kw=20))

            cl.user_session.set("chain", retrivalAux.getRetriever())
            
        case "chat2":

            msg.content = "Hola"
            if retrivalAux.getNombre() != "chat2" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat2")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
            
        case "chat3":
            cl.user_session.set("chain", None)
            msg.content = "Hola"
            db_faiss =""
            if retrivalAux.getNombre() != "chat3" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre=cimabot_SISCRED_ALU.name)
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
            
            
        case "chat4":
            
            msg.content = "Hola"
            db_faiss ="./vectorstores/chat4"
            if retrivalAux.getNombre() != "chat4" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat4")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat5":
            cl.user_session.set("chain", None)
            msg.content = "hi"
            #chain =  qa_bot(db="./vectorstores/chat5",kw=20)
            db_faiss ="./vectorstores/SIMA/alumno"
            if retrivalAux.getNombre() != chat5 or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre=chat5)
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat6":

            msg.content = "hi"
            db_faiss = "./vectorstores/chat6"
            if retrivalAux.getNombre() != "chat6" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat6")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
        case "chat7":

            msg.content = "Hola"
            db_faiss = "./vectorstores/chat7"

            if retrivalAux.getNombre() != "chat7" or not retrivalAux:
                cl.user_session.set("chain", None)
                if retrivalAux:
                    retrivalAux.destroy()
                retrivalAux.setNombre(nombre="chat7")
                retrivalAux.setRetrival(retrivalQA=qa_bot(db=db_faiss,kw=20))
            cl.user_session.set("chain", retrivalAux.getRetriever())
    await msg.update()

    

@cl.on_message
async def main(message: cl.Message):

    chain = cl.user_session.get("chain")

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]

    
    await cl.Message(content=answer).send()
    
@cl.on_chat_end
def end():
    print("--------------adios--------")
    global retrivalAux
    cl.user_session.set("chain", None)
    
    retrivalAux.destroy()
    
    
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    run_chainlit(__file__)

Retrival.py



class Retrival:
    
    def __init__(self, nombre=None,retrivalQA = None):
        
        self.retriever = retrivalQA;
        self.nombre = nombre;
        print(nombre)
    
    def getRetriever(self):
        return self.retriever
    
    def getNombre(self):
        return self.nombre
    
    def setNombre(self, nombre=None):
        del self.nombre
        self.nombre = nombre
    
    def setRetrival(self,retrivalQA = None):
        self.retriever = retrivalQA
    
    
    
    def destroy(self):
        if self.retriever:
            print("Memory--------------")
            print(self.retriever.memory)
        
        
        del self.retriever
        del self.nombre
        
        self.nombre = None
        self.retriever = None
        print(self.retriever)

I want to a solution for delete the process of RetrievalQA after using it.


Solution

  • i fixed this issue, i moved out the embedding of the qa_bot function, then i fixed, because when i change chat profile it creates another embedding but all chat profiles use the same embedding XD so i moved out and i fixed it.