python langchain amazon-bedrock claude

Langchain workaround for with_structured_output using ChatBedrock


I'm working with the langchain library to implement a document analysis application. Specifically, I want to use the routing technique described in this documentation. I wanted to follow along with the example, but my environment is restricted to AWS, so I am using ChatBedrock instead of ChatOpenAI because of limitations in my deployment.

According to this overview, the with_structured_output method, which I need, is not (yet) implemented for models on AWS Bedrock. I am therefore looking for a workaround or another way to replicate this functionality.
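One common way to approximate with_structured_output without native support is to list the allowed labels in the prompt, ask the model to answer with JSON only, and validate the reply yourself. The sketch below covers only the parsing and validation half, using the standard library; `build_format_instructions` and `parse_route_query` are hypothetical helper names, and the reply string stands in for whatever ChatBedrock returns. LangChain's `PydanticOutputParser` follows the same pattern if you prefer to stay inside LangChain.

```python
import json
import re
from dataclasses import dataclass
from typing import List, Literal, get_args

# Allowed labels, mirroring the datasources used in the routing example.
Datasource = Literal["python_docs", "js_docs", "golang_docs"]
ALLOWED = set(get_args(Datasource))


@dataclass
class RouteQuery:
    datasources: List[str]


def build_format_instructions() -> str:
    """Hypothetical helper: text to add to the system prompt."""
    options = ", ".join(sorted(ALLOWED))
    return (
        'Respond with JSON only, in the form {"datasources": [...]}, '
        "choosing one or more of: " + options
    )


def parse_route_query(text: str) -> RouteQuery:
    """Validate the model reply against the allowed datasource labels."""
    # Take everything from the first '{' to the last '}' so that chatty
    # preambles such as "Sure!" are ignored.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model output")
    data = json.loads(match.group(0))  # JSONDecodeError is a ValueError
    labels = data.get("datasources") if isinstance(data, dict) else None
    if not isinstance(labels, list):
        raise ValueError("'datasources' must be a list")
    unknown = [l for l in labels if not isinstance(l, str) or l not in ALLOWED]
    if unknown:
        raise ValueError(f"unknown datasources: {unknown}")
    return RouteQuery(datasources=labels)


# Stand-in for the text a Bedrock chat model might return.
reply = 'Sure! {"datasources": ["python_docs", "js_docs"]}'
print(parse_route_query(reply))  # → RouteQuery(datasources=['python_docs', 'js_docs'])
```

With a ChatPromptTemplate, pass the instructions in as a template variable (for example via `prompt.partial(...)`) rather than pasting them into the template string, since the template treats `{...}` as a placeholder; then apply `parse_route_query` to the `.content` of the model's reply. Raising on invalid labels, instead of guessing, keeps bad routes from reaching the follow-up steps.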

The key functionality I am looking for is shown in this example:

from typing import List
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI



class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasources: List[Literal["python_docs", "js_docs", "golang_docs"]] = Field(
        ...,
        description="Given a user question choose which datasources would be most relevant for answering their question",
    )

system = """You are an expert at routing a user question to the appropriate data source.

Based on the programming language the question is referring to, route it to the relevant data source."""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(RouteQuery)
router = prompt | structured_llm
router.invoke(
    {
        "question": "is there feature parity between the Python and JS implementations of OpenAI chat models"
    }
)

The output would be:

RouteQuery(datasources=['python_docs', 'js_docs'])

The most important thing for me is that it only selects items from the list and returns nothing else, which makes it possible to set up the right follow-up questions.
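Because the router returns only labels from a fixed list, the follow-up step can be a plain lookup table keyed by label. A minimal sketch, assuming one hypothetical follow-up function per datasource; `FOLLOW_UPS`, `run_follow_ups`, and the `ask_*` functions are illustrative names, not LangChain APIs, and in practice each function would call its own retriever or chain.

```python
from typing import Callable, Dict, List


# Hypothetical follow-up steps, one per datasource label.
def ask_python_docs(question: str) -> str:
    return f"[python_docs] {question}"


def ask_js_docs(question: str) -> str:
    return f"[js_docs] {question}"


def ask_golang_docs(question: str) -> str:
    return f"[golang_docs] {question}"


FOLLOW_UPS: Dict[str, Callable[[str], str]] = {
    "python_docs": ask_python_docs,
    "js_docs": ask_js_docs,
    "golang_docs": ask_golang_docs,
}


def run_follow_ups(datasources: List[str], question: str) -> List[str]:
    """Run the follow-up for each routed datasource, in the order returned.

    An unknown label raises KeyError, which surfaces routing errors early.
    """
    return [FOLLOW_UPS[name](question) for name in datasources]


results = run_follow_ups(["python_docs", "js_docs"], "Is there feature parity?")
print(results)
# → ['[python_docs] Is there feature parity?', '[js_docs] Is there feature parity?']
```

With the router from the question, the input would be `router.invoke(...).datasources`.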

Has anyone found a workaround for this issue?


Solution

  • I found a solution in these two blog posts: here and here. The key is to use the instructor package, which wraps LLM client libraries (here Anthropic's Bedrock client) and returns responses validated against a Pydantic model. This means LangChain is not necessary.

    Here is an example based on the blog posts:

    from typing import List
    import instructor
    from anthropic import AnthropicBedrock
    from loguru import logger
    from pydantic import BaseModel
    import enum
    
    class User(BaseModel):
        name: str
        age: int
    
    class MultiLabels(str, enum.Enum):
        TECH_ISSUE = "tech_issue"
        BILLING = "billing"
        GENERAL_QUERY = "general_query"
    
    class MultiClassPrediction(BaseModel):
        """
        Class for a multi-class label prediction.
        """
        class_labels: List[MultiLabels]
    
    if __name__ == "__main__":
        # Initialize the instructor client with AnthropicBedrock configuration
        client = instructor.from_anthropic(
            AnthropicBedrock(
                aws_region="eu-central-1",
            )
        )
    
        logger.info("Hello World Example")
    
        # Create a message and extract user data
        resp = client.messages.create(
            model="anthropic.claude-instant-v1",
            max_tokens=1024,
            messages=[
                {
                    "role": "user",
                    "content": "Extract Jason is 25 years old.",
                }
            ],
            response_model=User,
        )
    
        print(resp)
        logger.info("Classification Example")
    
        # Classify a support ticket
        text = "My account is locked and I can't access my billing info."
    
        _class = client.messages.create(
            model="anthropic.claude-instant-v1",
            max_tokens=1024,
            response_model=MultiClassPrediction,
            messages=[
                {
                    "role": "user",
                    "content": f"Classify the following support ticket: {text}",
                },
            ],
        )
    
        print(_class)
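Conceptually, instructor's `response_model` validates the model's reply against the Pydantic model and, when `max_retries` is set, sends the validation error back to the model and asks again. The standard-library sketch below imitates that loop with an Enum like `MultiLabels`. It is a simplified illustration, not instructor's actual implementation; `create_with_retries` and `fake_model` are hypothetical stand-ins, and no Bedrock call is made.

```python
import enum
import json
from typing import Callable, List


class MultiLabels(str, enum.Enum):
    TECH_ISSUE = "tech_issue"
    BILLING = "billing"
    GENERAL_QUERY = "general_query"


def validate(text: str) -> List[MultiLabels]:
    """Parse {"class_labels": [...]} and convert each label to the Enum."""
    data = json.loads(text)
    # MultiLabels("unknown") raises ValueError, so unlisted labels are rejected.
    return [MultiLabels(label) for label in data["class_labels"]]


def create_with_retries(
    call_model: Callable[[List[dict]], str],
    messages: List[dict],
    max_retries: int = 2,
) -> List[MultiLabels]:
    """Call the model; on a validation error, feed the error back and retry.

    max_retries counts extra attempts after the first call.
    """
    history = list(messages)  # copy so the caller's list is not modified
    failures = 0
    while True:
        reply = call_model(history)
        try:
            return validate(reply)
        except (ValueError, KeyError, TypeError) as err:
            failures += 1
            if failures > max_retries:
                raise
            history.append({"role": "assistant", "content": reply})
            history.append(
                {"role": "user", "content": f"Fix this error and answer again: {err}"}
            )


# Hypothetical model: the first reply uses an unlisted label, the second is valid.
replies = iter([
    '{"class_labels": ["account_locked"]}',
    '{"class_labels": ["tech_issue", "billing"]}',
])
calls = []  # history length seen on each call


def fake_model(history: List[dict]) -> str:
    calls.append(len(history))
    return next(replies)


ticket = "My account is locked and I can't access my billing info."
labels = create_with_retries(
    fake_model, [{"role": "user", "content": f"Classify the following support ticket: {ticket}"}]
)
print([label.value for label in labels])  # → ['tech_issue', 'billing']
```

The second call sees three messages (the original request, the rejected reply, and the error feedback), which is what lets the model correct itself instead of failing on the first bad label.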