Search code examples
python pydantic

How to serialize a pydantic model with polymorphism?


I tried to serialize a pydantic model with an attribute whose value can be an instance of any of several subclasses of a base class. However, with a naive implementation the subclass instances are serialized as the base class.

After reading this issue, I wrote the following code, but without success:

from typing import Dict, Literal, Union

from pydantic import BaseModel, Field, RootModel


class NodeBase(BaseModel):  # Common base class of every node kind.
    id: str  # Node identifier; Graph.add_node uses it as the key in Graph.nodes. (Shadows the builtin `id` inside the class body.)


class StartNode(NodeBase):  # Node kind tagged "start"; adds no fields beyond the tag.
    type: Literal["start"] = "start"  # Constant tag: "start" is the only accepted value and also the default.


class EndNode(NodeBase):  # Node kind tagged "end"; adds no fields beyond the tag.
    type: Literal["end"] = "end"  # Constant tag: "end" is the only accepted value and also the default.


class LLMNode(NodeBase):  # Node kind tagged "llm"; carries a prompt, a purpose and a model name.
    type: Literal["llm"] = "llm"  # Constant tag: "llm" is the only accepted value and also the default.
    name: str = Field(default_factory=lambda: id)  # NOTE(review): `id` here resolves to the builtin function, not this node's `id` field, so the default is not a str -- consistent with the "builtin_function_or_method" serializer warning quoted below.
    purpose: str  # Required; no default.
    prompt: str  # Required; no default.
    model: Literal[
        "gpt-4o", "gpt4-turbo", "gpt-4", "gpt-3.5-turbo", "azure-gpt-3.5-turbo"
    ]  # Required; must be one of the listed model names.


class NodeModel(RootModel):  # Root-model wrapper holding exactly one node of any concrete kind.
    root: Union[StartNode, EndNode, LLMNode]  # Plain (untagged) union of the three concrete node classes.


class Graph(BaseModel):  # Container of nodes, keyed by node id.
    nodes: Dict[str, NodeModel] = Field(default_factory=dict)  # Starts as an empty dict; values are NodeModel wrappers, not bare nodes.

    def add_node(self, node: Union[StartNode, EndNode, LLMNode]) -> None:  # Register `node` under its own id.
        self.nodes[node.id] = NodeModel(root=node)  # Wraps the node in NodeModel; an existing entry with the same id is overwritten.


start_node = StartNode(id="start", type="start")  # One instance of each node kind, used by both test sections.
llm_node = LLMNode(id="llm", type="llm", purpose="test", prompt="test", model="gpt-4o")  # `name` is omitted, so its default_factory supplies the value.
end_node = EndNode(id="end", type="end")

# ========= Node tests =========
start_node_dict = start_node.model_dump()  # Dump each node to a plain dict.
llm_node_dict = llm_node.model_dump()  # NOTE(review): presumably the source of the serializer warning quoted below (`name` holds the builtin `id`).
end_node_dict = end_node.model_dump()
# Is it possible to use model_validate with the base class?
start_node_from_dict = NodeBase.model_validate(start_node_dict)  # NOTE(review): NodeBase declares only `id`; presumably this yields a plain NodeBase rather than the original subclass -- confirm.
llm_node_from_dict = NodeBase.model_validate(llm_node_dict)
end_node_from_dict = NodeBase.model_validate(end_node_dict)


assert start_node == start_node_from_dict  # This is the assertion that fails in the traceback quoted below.
assert llm_node == llm_node_from_dict
assert end_node == end_node_from_dict


# ========= Graph tests =========
g = Graph()  # NOTE: not reached in the reported run; the script stops at the first failing assert above.
g.add_node(start_node)
g.add_node(llm_node)
g.add_node(end_node)

g_dict = g.model_dump()  # Round-trip the whole graph through a plain dict.
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict  # Intended check: every node comes back with its original type and values.

Running this code gives the following errors:

UserWarning: Pydantic serializer warnings:
  Expected `str` but got `builtin_function_or_method` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
Traceback (most recent call last):
  File "file.py", line 53, in <module>
    assert start_node == start_node_from_dict
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

I would like to be able to dump a Graph whose nodes are instances of NodeBase subclasses, such as StartNode or LLMNode, and to deserialize the graph back so that every node has the right type. In addition, it would be great if I could also deserialize an instance of a NodeBase subclass without knowing its type in advance, directly with NodeBase.model_validate(subclass_of_NodeBase_that_I_dont_know_the_type_of).


Solution

  • Thank you for your feedback; here is the solution I used for my problem.

    # Create the NodeTypes union from the node types list
    # NOTE(review): `node_types` is not defined in this snippet; presumably it
    # is a list of the concrete node classes (StartNode, EndNode, LLMNode, ...)
    # built elsewhere -- confirm.
    NodeTypes = Union[tuple(node_types)] # shouldn't contain NodeBase
    
    
    class NodeModel(RootModel):
        """Root-model wrapper whose validator hands back the wrapped node itself."""

        root: NodeTypes  # Any one of the concrete node classes in the union above.
    
        # NOTE(review): `model_validator` is not imported in the code shown in
        # the question; it needs to come from pydantic.
        # NOTE(review): this stacks `@classmethod` under an after-mode model
        # validator -- confirm this form is supported by the installed pydantic
        # version.
        @model_validator(mode="after")
        @classmethod
        def get_root(cls, obj):
            """Return ``obj.root`` when ``obj`` has a ``root`` attribute, else ``obj`` unchanged.

            Running after validation, this replaces the ``NodeModel`` wrapper
            with the node it wraps.
            """
            if hasattr(obj, "root"):
                return obj.root
            return obj
    

    And use a different way of adding nodes to the graph:

    def add_node(self: Self, node: NodeBase) -> None:
        """Add a node to the graph, keyed by its ``id``.

        The node object is stored as-is; an existing entry with the same id
        is replaced.

        :param node: A ``NodeBase`` instance (typically one of its subclasses).
        """
        # NOTE(review): `Self` is not imported in the code shown in the
        # question; it would need to come from `typing` / `typing_extensions`.
        self.nodes[node.id] = node