Tags: python, langchain, large-language-model, retrieval-augmented-generation

LangChain: How to view the context my retriever used when invoking a chain


I am trying to build a private LLM with RAG capabilities. I followed a few tutorials and successfully built one, but I want to view the context that the MultiVectorRetriever used when LangChain invokes my query.

This is my code:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough
from PIL import Image
import io
import os
import uuid
import json
import base64

def convert_bytes_to_base64(image_bytes):
    encoded_string = base64.b64encode(image_bytes).decode("utf-8")
    # The figures are re-encoded as PNG below, so label the data URI as PNG
    return "data:image/png;base64," + encoded_string

#Load Retriever

path="./vectorstore/pdf_test_file.pdf"

#Load from JSON files
texts = json.load(open(os.path.join(path, "json", "texts.json")))
text_summaries = json.load(open(os.path.join(path, "json", "text_summaries.json")))
tables = json.load(open(os.path.join(path, "json", "tables.json")))
table_summaries = json.load(open(os.path.join(path, "json", "table_summaries.json")))
img_summaries = json.load(open(os.path.join(path, "json", "img_summaries.json")))

#Load from figures
images_base64_list = []
for image in sorted(os.listdir(os.path.join(path, "figures"))):
    # sorted(): os.listdir returns entries in arbitrary order, but the images are
    # paired with img_ids by index below (assumes summaries were made in sorted order)
    img = Image.open(os.path.join(path, "figures", image))
    buffered = io.BytesIO()
    img.save(buffered, format="png")
    image_base64 = convert_bytes_to_base64(buffered.getvalue())
    # Warning: this section of the code does not work in external IDEs like Spyder and will break. Run it locally in the native terminal.
    images_base64_list.append(image_base64)


#Add to vectorstore

# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    collection_name="summaries", embedding_function=GPT4AllEmbeddings()
)

# The storage layer for the parent documents
store = InMemoryStore()  # <- Can we extend this to images
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

# Add images
img_ids = [str(uuid.uuid4()) for _ in img_summaries]
summary_img = [
    Document(page_content=s, metadata={id_key: img_ids[i]})
    for i, s in enumerate(img_summaries)
]
retriever.vectorstore.add_documents(summary_img)
retriever.docstore.mset(
    list(zip(img_ids, img_summaries))
)  # Store the image summary as the raw document


# Pair each image ID with its base64-encoded image
img_summaries_ids_and_images_base64 = [
    [img_ids[i], img] for i, img in enumerate(images_base64_list)
]



# Check Response

# Question Example: "What is the issues plagueing the acres?"

"""
Testing Retrieval

print("\nTesting Retrieval: \n")
prompt = "Images / figures with playful and creative examples"
response = retriever.get_relevant_documents(prompt)[0]
print(response)

"""

"""
retriever.vectorstore.similarity_search("What is the issues plagueing the acres? show any relevant tables",k=10)
"""

# Prompt template
template = """Answer the question based only on the following context, which can include text, tables and images/figures:
{context}
Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)

# Multi-modal LLM
# model = LLaVA
model = ChatOllama(model="custom-mistral")

# RAG pipeline
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

print("\n\n\nTesting Response: \n")

print(chain.invoke(
    "What is the issues plagueing the acres? show any relevant tables"
))

The output will look something like this:


Testing Response:

In the provided text, the main issue with acres is related to wildfires and their impact on various lands and properties. The text discusses the number of fires, acreage burned, and the level of destruction caused by wildfires in the United States from 2018 to 2022. It also highlights that most wildfires are human-caused (89% of the average number of wildfires from 2018 to 2022) and that fires caused by lightning tend to be slightly larger and burn more acreage than those caused by humans.

Here's the table provided in the text, which shows the number of fires and acres burned on federal lands (by different organizations), other non-federal lands, and total:

| Year | Number of Fires (thousands) | Acres Burned (millions) |
|------|-----------------------------|--------------------------|
| 2018 | 58.1                        | 8.8                      |
| 2019 | 58.1                        | 4.7                      |
| 2020 | 58.1                        | 10.1                     |
| 2021 | 58.1                        | 10.1                     |
| 2022 | 58.1                        | 3.6                      |

The table also breaks down the acreage burned by federal lands (DOI and FS) and other non-federal lands, as well as showing the total acreage burned each year.<|im_end|>

From the RAG pipeline, I want to print out the context the retriever (which stores many vector embeddings) supplied for the query, i.e. which of the stored documents it actually used. Something like:

chain.invoke("What is the issues plagueing the acres? show any relevant tables").get_context_used()

I know there are functions like

retriever.get_relevant_documents(prompt) 

and

retriever.vectorstore.similarity_search(prompt) 

which return the most relevant context for a query, but I'm unsure whether the invoke call pulls the same context as those two functions.

The retriever I'm using from LangChain is the MultiVectorRetriever.
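On whether invoke uses the same context: inside the chain, the {"context": retriever} step calls the retriever with the question string, which runs the same lookup as retriever.get_relevant_documents(question). retriever.vectorstore.similarity_search(question) is different: it returns the matching summaries, and the MultiVectorRetriever then maps each summary's doc_id to the parent document in the docstore, dropping duplicate IDs. The toy model below is plain Python, not LangChain's source; it mimics that two-step lookup under the assumption that word-overlap scoring stands in for embedding similarity, and Doc, the sample data, and the similarity_search/retrieve helpers are made up for illustration.

```python
# Toy model of MultiVectorRetriever's two-step lookup (NOT LangChain's code).
# Assumption: word overlap stands in for embedding similarity.
from dataclasses import dataclass, field

ID_KEY = "doc_id"


@dataclass
class Doc:
    """Minimal stand-in for langchain_core.documents.Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


# "Vectorstore" contents: summaries (child docs) tagged with their parent's id.
summaries = [
    Doc("Summary: wildfire counts and acres burned by year", {ID_KEY: "t1"}),
    Doc("Summary: table of acres burned on federal lands", {ID_KEY: "tab1"}),
    Doc("Summary: second chunk about acres burned", {ID_KEY: "t1"}),
    Doc("Summary: photo of a forest", {ID_KEY: "img1"}),
]

# "Docstore" contents: parent documents (raw text, tables, images) keyed by id.
docstore = {
    "t1": "Raw text chunk about wildfires and acres burned",
    "tab1": "Raw table: Year | Number of fires | Acres burned",
    "img1": "Image summary stored as the raw document",
}


def similarity_search(query, k=4):
    """Return the k summaries that share the most words with the query."""
    words = set(query.lower().split())
    return sorted(
        summaries,
        key=lambda d: len(words & set(d.page_content.lower().split())),
        reverse=True,  # sorted() stays stable, so ties keep insertion order
    )[:k]


def retrieve(query, k=4):
    """Search the summaries, then swap each hit for its parent document."""
    ids = []
    for doc in similarity_search(query, k=k):
        doc_id = doc.metadata.get(ID_KEY)
        if doc_id is not None and doc_id not in ids:  # de-duplicate parents
            ids.append(doc_id)
    return [docstore[i] for i in ids if i in docstore]


hits = similarity_search("acres burned", k=3)
print([d.metadata[ID_KEY] for d in hits])  # ['t1', 'tab1', 't1']
print(retrieve("acres burned", k=3))       # parents t1 and tab1 only
```

Here three summaries match, but two belong to the same parent, so the retriever step yields only two parent documents. Printing retriever.get_relevant_documents(question) therefore shows what the chain's context receives for the same question (with the same search_kwargs), while similarity_search shows which summaries matched.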


Solution

  • You can tap into a LangChain chain with a RunnableLambda that prints the state passed from the retriever to the prompt and then passes it on unchanged:

    from langchain_core.runnables import RunnableLambda
    
    def inspect(state):
        """Print the state passed between Runnables in a langchain and pass it on"""
        print(state)
        return state
    
    # RAG pipeline
    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | RunnableLambda(inspect)  # Add the inspector here to print the intermediate results
        | prompt
        | model
        | StrOutputParser()
    )