I tried to serialize a pydantic model with an attribute that can be of class of multiple subclasses of a base class. However with a naive implementation the subclasses are serialized to the baseclass.
After reading this issue I wrote the following code but without any success:
from typing import Dict, Literal, Union
from pydantic import BaseModel, Field, RootModel
class NodeBase(BaseModel):
id: str
class StartNode(NodeBase):
type: Literal["start"] = "start"
class EndNode(NodeBase):
type: Literal["end"] = "end"
class LLMNode(NodeBase):
type: Literal["llm"] = "llm"
name: str = Field(default_factory=lambda: id)
purpose: str
prompt: str
model: Literal[
"gpt-4o", "gpt4-turbo", "gpt-4", "gpt-3.5-turbo", "azure-gpt-3.5-turbo"
]
class NodeModel(RootModel):
root: Union[StartNode, EndNode, LLMNode]
class Graph(BaseModel):
nodes: Dict[str, NodeModel] = Field(default_factory=dict)
def add_node(self, node: Union[StartNode, EndNode, LLMNode]) -> None:
self.nodes[node.id] = NodeModel(root=node)
start_node = StartNode(id="start", type="start")
llm_node = LLMNode(id="llm", type="llm", purpose="test", prompt="test", model="gpt-4o")
end_node = EndNode(id="end", type="end")
# ========= Node tests =========
start_node_dict = start_node.model_dump()
llm_node_dict = llm_node.model_dump()
end_node_dict = end_node.model_dump()
# Is it possible to use model_validate with the base class?
start_node_from_dict = NodeBase.model_validate(start_node_dict)
llm_node_from_dict = NodeBase.model_validate(llm_node_dict)
end_node_from_dict = NodeBase.model_validate(end_node_dict)
assert start_node == start_node_from_dict
assert llm_node == llm_node_from_dict
assert end_node == end_node_from_dict
# ========= Graph tests =========
g = Graph()
g.add_node(start_node)
g.add_node(llm_node)
g.add_node(end_node)
g_dict = g.model_dump()
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict
give the following errors :
UserWarning: Pydantic serializer warnings:
Expected `str` but got `builtin_function_or_method` - serialized value may not be as expected
return self.__pydantic_serializer__.to_python(
Traceback (most recent call last):
File "file.py", line 53, in <module>
assert start_node == start_node_from_dict
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
I would like to be able to dump a Graph whether the added nodes are subclasses of Node such as StartNode or LLMNode and be able to deserialize a graph back where all nodes have the right types. In addition it would be great if I could also deserialize a subclass of Node without knowing which type directly with a NodeBase.model_validate(subclass_of_NodeBase_that_I_dont_know_the_type_of)
Thank you for your feedbacks, here is the solution I used for my problem.
# Create the NodeTypes union from the node types list
NodeTypes = Union[tuple(node_types)] # shouldn't contain NodeBase
class NodeModel(RootModel):
root: NodeTypes
@model_validator(mode="after")
@classmethod
def get_root(cls, obj):
if hasattr(obj, "root"):
return obj.root
return obj
And have a different way of adding nodes to the graph
def add_node(self: Self, node: NodeBase) -> None:
"""Add a node to the graph.
:param node: An instance of the Node class
"""
self.nodes[node.id] = node