
Sagemaker and LangChain: ValueError when calling InvokeEndpoint operation for Llama 2 model


I am trying to deploy a Llama 2 model for text generation inference using SageMaker and LangChain. I write the code in a notebook instance and deploy SageMaker instances from that code. I followed the documentation at https://python.langchain.com/docs/integrations/llms/sagemaker and used the following code to create a question-answering chain:

from langchain.docstore.document import Document
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [
    Document(
        page_content=example_doc_1,
    )
]

from typing import Dict

from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
import json

query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({prompt: prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="XYZ",
        credentials_profile_name="XYZ",
        region_name="XYZ",
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

chain({"input_documents": docs, "question": query}, return_only_outputs=True)

But I got the following error:

ValueError: Error raised by inference endpoint: 
An error occurred (ModelError) when calling the InvokeEndpoint operation: 
Received client error (422) from primary with message 
"Failed to deserialize the JSON body into the target type: missing field `inputs` at line 1 column 966".

Several tutorials I have seen do not use an inputs field at all. I don't know whether the documentation I have been following has since been updated, but I can't resolve this problem.

My questions are:

  • Why am I getting this error, and how can I fix it?
  • What am I missing in my code or configuration?

Any help or guidance would be appreciated. Thanks in advance.

Solution

  • This looks like a known issue with the LangChain documentation. @sigvamo mentioned that the error can be worked around by updating ContentHandler so that its transform_input method includes the inputs in the request body, shown here for an embeddings handler:

    from typing import Dict, List
    from langchain.embeddings import SagemakerEndpointEmbeddings
    from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
    import json
    
    class ContentHandler(EmbeddingsContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
            # Send the texts under an explicit key so the endpoint can deserialize the body.
            input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
            return input_str.encode("utf-8")

        def transform_output(self, output: bytes) -> List[List[float]]:
            # The response body is read, decoded, and parsed as JSON.
            response_json = json.loads(output.read().decode("utf-8"))
            return response_json["embedding"]
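
For the text-generation handler in the question, the same idea might look like the sketch below. It uses plain functions rather than LLMContentHandler methods so that it runs without LangChain or an endpoint; the bodies could be moved into transform_input and transform_output. The error message only confirms that a top-level inputs field is required. Nesting model_kwargs under "parameters" is an assumption based on the Hugging Face text-generation request format, so check what your container expects. build_payload_original is a hypothetical helper that reproduces the question's payload, and the temperature value is purely illustrative.

```python
import io
import json
from typing import Dict


def build_payload_original(prompt: str, model_kwargs: Dict) -> dict:
    # Mirrors the question's handler: the dict key is the prompt's *value*
    # (an unquoted variable), so no field named "inputs" is ever sent.
    return {prompt: prompt, **model_kwargs}


def transform_input(prompt: str, model_kwargs: Dict) -> bytes:
    # Assumed request schema: {"inputs": <text>, "parameters": {...}}.
    payload = {"inputs": prompt, "parameters": model_kwargs}
    return json.dumps(payload).encode("utf-8")


def transform_output(output) -> str:
    # `output` is a file-like response body; the assumed response shape
    # is [{"generated_text": "..."}], as in the question's handler.
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json[0]["generated_text"]


prompt = "Question: How long was Elizabeth hospitalized?\nAnswer:"
kwargs = {"temperature": 0.01}  # illustrative value only

# The original payload has no "inputs" key, which matches the 422 error.
print("inputs" in build_payload_original(prompt, kwargs))  # False

# The adjusted payload carries the prompt under "inputs".
body = json.loads(transform_input(prompt, kwargs))
print(sorted(body))  # ['inputs', 'parameters']

# Simulated endpoint response, standing in for the real streaming body.
fake_response = io.BytesIO(
    json.dumps([{"generated_text": "3 days"}]).encode("utf-8")
)
print(transform_output(fake_response))  # 3 days
```

If the endpoint still rejects the request, printing the decoded result of transform_input before invoking the chain is a quick way to compare the body against the schema your container documents.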