Tags: python, jupyter-notebook, agent, py-langchain, langgraph

"agent_node() got multiple values for argument 'agent'" when extracting LangChain example code from a notebook


I'm running the example LangChain/LangGraph code for "Basic Multi-agent Collaboration." I got the example from here (GitHub). There is also a blog post and a video.

After configuring my local virtual environment and copying and pasting the notebook, everything works fine.

In the same virtual environment, I'm adapting the code so that it no longer needs the Jupyter Notebook. To do so, I split the code into a few classes.

Original code (working):


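import functools

from langchain_core.messages import FunctionMessage, HumanMessage
from langchain_openai import ChatOpenAI

# Helper function to create a node for a given agent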
def agent_node(state, agent, name):
    result = agent.invoke(state)
    # We convert the agent output into a format that is suitable to append to the global state
    if isinstance(result, FunctionMessage):
        pass
    else:
        result = HumanMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        # Since we have a strict workflow, we can
        # track the sender so we know who to pass to next.
        "sender": name,
    }


llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=open_ai_key)

# Research agent and node
research_agent = create_agent(
    llm, 
    [search], 
    system_message="You should provide accurate data for the chart generator to use.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# Chart Generator
chart_agent = create_agent(
    llm,
    [python_repl],
    system_message="Any charts you display will be visible by the user.",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="Chart Generator")

#...some other cells

for s in graph.stream(
    {
        "messages": [
            HumanMessage(
                content="Fetch the UK's GDP over the past 5 years,"
                " then draw a line graph of it."
                " Once you code it up, finish."
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 100},
):
    print(s)
    print("----")
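For comparison, the reason the notebook version works can be sketched without LangChain. This is a minimal sketch, assuming LangGraph calls each node with the current state as its single positional argument (which is what the error message below implies). StubAgent is a hypothetical stand-in for the object returned by create_agent; only its invoke() method matters here.

```python
import functools
import inspect


class StubAgent:
    """Hypothetical stand-in for the agent built by create_agent(...)."""

    def __init__(self, reply):
        self.reply = reply

    def invoke(self, state):
        # A real agent would call the LLM; the stub just returns a fixed reply.
        return self.reply


def agent_node(state, agent, name):
    # Module-level function: there is no implicit `self`, so `state` is
    # genuinely the first parameter, as in the notebook.
    result = agent.invoke(state)
    return {"messages": [result], "sender": name}


# partial() pre-fills `agent` and `name` as keywords and leaves `state` open.
research_node = functools.partial(agent_node, agent=StubAgent("GDP data"), name="Researcher")

# Call the node the way the graph does: state as the only positional argument.
print(research_node({"messages": []}))

# `state` is still the first (and only positional) parameter of the partial.
print(list(inspect.signature(research_node).parameters))
```

Because agent_node is a plain function here, the positional state dict lands in state and the keywords fill agent and name with no conflict.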

My adapted version (not working):

nodes.py

import functools

from langchain_core.messages import FunctionMessage, HumanMessage

class Nodes:

    def __init__(self, llm, tools):
        self.llm = llm
        self.tools = tools

    def agent_node(state, agent, name):
       
        result = agent.invoke(state)
        if isinstance(result, FunctionMessage):
            pass
        else:
            result = HumanMessage(**result.dict(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            # track the sender so we know who to pass to next.
            "sender": name,
        }
    
    def research_agent_node(self):
        
        research_agent = create_agent(
            self.llm, 
            [self.tools[0]], 
            system_message="You should provide accurate data for the chart generator to use.",
        )
        return functools.partial(self.agent_node, agent=research_agent, name="Researcher")
    
    def chart_generator_node(self):
        
        chart_agent = create_agent(
            self.llm,
            [self.tools[0]],
            system_message="Any charts you display will be visible by the user.",
        )
        return functools.partial(self.agent_node, agent=chart_agent, name="Chart Generator")

main.py

from langchain_core.messages import HumanMessage

from nodes import Nodes

# Get the agent and tool nodes
nodes = Nodes(llm, tools)
research_agent_node = nodes.research_agent_node()
chart_agent_node = nodes.chart_generator_node()


#...more code

for s in graph.stream(
        {
            "messages": [
                HumanMessage(
                    content="Fetch the UK's GDP over the past 5 years,"
                    " then draw a line graph of it."
                    " Once you code it up, finish."
                )
            ],
        },
        # Maximum number of steps to take in the graph
        {"recursion_limit": 100},
    ):
    print(s)
    print("----")

The error message is TypeError: Nodes.agent_node() got multiple values for argument 'agent'. The message itself is easy enough to understand, but the cause is hard to track down. Maybe I misunderstood how to return the functools.partial, but I don't know how to verify that.
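One way to verify what a functools.partial will do is to inspect its signature and call it by hand the same way the graph does. The sketch below reproduces the error without LangChain or LangGraph; the agent value is just a placeholder string, and it assumes the graph passes the state dict as the node's only positional argument.

```python
import functools
import inspect


class Nodes:
    # Reproduces the bug: a regular method that has no `self` parameter.
    def agent_node(state, agent, name):
        return {"messages": [agent], "sender": name}


nodes = Nodes()
# nodes.agent_node is a *bound* method: the instance is already bound to `state`.
broken = functools.partial(nodes.agent_node, agent="research_agent", name="Researcher")

# 1. Inspect the signature the graph will actually see.
sig = inspect.signature(broken)
print(sig)  # no `state` parameter left: the Nodes instance already filled it

# 2. Call the node like the graph does: the state as the only positional argument.
error_message = None
try:
    broken({"messages": []})
except TypeError as exc:
    # The dict is passed positionally into `agent`, which partial() also
    # supplies as a keyword, so Python raises "multiple values for argument".
    error_message = str(exc)
    print(error_message)
```

If the printed signature does not start with state, the graph's positional argument will end up in the wrong parameter.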


Solution

  • For anyone facing a similar problem, the solution is to make agent_node a static method.

    nodes.py

        @staticmethod
        def agent_node(state, agent, name):
           
            result = agent.invoke(state)
            if isinstance(result, FunctionMessage):
                pass
            else:
                result = HumanMessage(**result.dict(exclude={"type", "name"}), name=name)
            return {
                "messages": [result],
                # track the sender so we know who to pass to next.
                "sender": name,
            }
    
        # ... same code

        def research_agent_node(self):
            # ... same create_agent(...) call as before
            return functools.partial(Nodes.agent_node, agent=research_agent, name="Researcher")
    
    
    

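    The problem happened because I defined agent_node as a regular instance method, and instance methods automatically receive self as their first argument. Since the method had no self parameter, the Nodes instance was silently bound to state. When the graph then called the partial with the state dict as its only positional argument, that dict landed in the second parameter, agent, which functools.partial was also passing as a keyword. That is why Python reported multiple values for 'agent'.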
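Both ways of repairing the method can be sketched with a hypothetical StubAgent standing in for create_agent(...). Option 1 is the @staticmethod fix described above. Option 2, declaring self explicitly (the name agent_node_with_self is hypothetical), is an alternative that should behave the same from the graph's point of view; it is not the fix used above.

```python
import functools


class StubAgent:
    """Hypothetical stand-in for the agent built by create_agent(...)."""

    def __init__(self, reply):
        self.reply = reply

    def invoke(self, state):
        return self.reply


class Nodes:
    # Option 1 (the fix above): a static method gets no implicit first argument,
    # so self.agent_node and Nodes.agent_node are the same plain function.
    @staticmethod
    def agent_node(state, agent, name):
        return {"messages": [agent.invoke(state)], "sender": name}

    # Option 2 (hypothetical alternative): a normal method that accepts `self`,
    # so the bound instance no longer displaces `state`.
    def agent_node_with_self(self, state, agent, name):
        return {"messages": [agent.invoke(state)], "sender": name}

    def research_agent_node(self):
        agent = StubAgent("GDP data")
        return functools.partial(self.agent_node, agent=agent, name="Researcher")

    def chart_generator_node(self):
        agent = StubAgent("chart code")
        return functools.partial(self.agent_node_with_self, agent=agent, name="Chart Generator")


nodes = Nodes()
research_node = nodes.research_agent_node()
chart_node = nodes.chart_generator_node()

# Each node now takes the state as its single positional argument.
print(research_node({"messages": []}))
print(chart_node({"messages": []}))
```

Option 2 is useful if the node later needs instance attributes such as self.llm; otherwise the static method keeps the original notebook signature unchanged.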