
How to use subclasses in pydantic nested model?


In pydantic, I want to create a model with a field that accepts any subclass of another model. Here is an example:

from pydantic import BaseModel


class Action(BaseModel):
    name: str


class LogAction(Action):
    log_level: str
    timestamp: str


class Alert(BaseModel):
    id: int
    message: str
    action: Action


alert1 = Alert(
    id=1,
    message="Alert Message",
    action=LogAction(name="Error Log", log_level="ERROR", timestamp="2024-04-20T10:45:00"),
)
print(alert1)
print(alert1.model_dump())

The alert includes the attributes specific to LogAction, but model_dump() includes only the attributes of the base class Action. Of course I could declare action: Action | LogAction, but I want action to accept any class derived from Action and still have all attributes of the subclass dumped/serialized. Is there a clever way to do this?
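A minimal sketch of what is going on, assuming pydantic v2 (the model_dump name suggests it). Validation keeps the LogAction instance unchanged, so the subclass data is stored on the model, but serialization follows the declared field type Action. The last line uses the per-call serialize_as_any flag of model_dump; I believe it was added in pydantic 2.7, so treat that version requirement as an assumption.

```python
from pydantic import BaseModel


class Action(BaseModel):
    name: str


class LogAction(Action):
    log_level: str
    timestamp: str


class Alert(BaseModel):
    id: int
    message: str
    action: Action  # the declared type, not the runtime type, drives serialization


alert1 = Alert(
    id=1,
    message="Alert Message",
    action=LogAction(name="Error Log", log_level="ERROR", timestamp="2024-04-20T10:45:00"),
)

# Validation keeps the subclass instance as-is, so the extra data is there...
print(type(alert1.action).__name__)  # LogAction

# ...but model_dump() serializes the field using Action's schema.
print(alert1.model_dump())
# {'id': 1, 'message': 'Alert Message', 'action': {'name': 'Error Log'}}

# Assumed pydantic >= 2.7: opt in to duck-typed serialization for a single call.
print(alert1.model_dump(serialize_as_any=True)["action"])
# {'name': 'Error Log', 'log_level': 'ERROR', 'timestamp': '2024-04-20T10:45:00'}
```

The stored object and its serialized form diverge only because the serializer is built from the annotation, which is why the question's model shows LogAction fields when printed but not when dumped.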


Solution

  • You can solve this by making Alert generic, with a type variable bound to Action.

    The disadvantage is that the generated JSON schema shows only Action, not its subclasses.

    from typing import Generic, TypeVar
    
    from pydantic import BaseModel, ValidationError
    
    
    class Action(BaseModel):
        name: str
    
    
    class LogAction(Action):
        log_level: str
        timestamp: str
    
    
    class AnotherAction(Action):
        something: str
    
    
    # T accepts Action or any subclass of it
    T = TypeVar("T", bound=Action)
    
    
    class Alert(BaseModel, Generic[T]):
        id: int
        message: str
        action: T  # validated against the bound Action
    
    
    alert1 = Alert(
        id=1,
        message="Alert Message",
        action=LogAction(name="Error Log", log_level="ERROR", timestamp="2024-04-20T10:45:00"),
    )
    print(alert1)
    print(alert1.model_dump())
    
    alert2 = Alert(
        id=1,
        message="Another Message",
        action=AnotherAction(name="another", something="123"),
    )
    print(alert2)
    print(alert2.model_dump())
    
    
    class AnotherAction2(BaseModel):  # Not inherited from Action
        something: str
    
    
    try:
        alert3 = Alert(
            id=1,
            message="Another Message",
            action=AnotherAction2(something="123"),
        )
    except ValidationError:
        # Rejected: AnotherAction2 is not a subclass of the bound Action
        print("alert3 validation error")
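A sketch checking why the generic approach works, assuming a recent pydantic v2 release. As I understand the pydantic generics documentation, a TypeVar with a bound that is left unparametrized is validated against the bound but serialized as Any, so the subclass fields survive; explicitly parametrizing as Alert[Action] brings the truncation back. The AlertAny model is a hypothetical name used to illustrate pydantic's SerializeAsAny annotation, an alternative that avoids generics altogether.

```python
from typing import Generic, TypeVar

from pydantic import BaseModel, SerializeAsAny


class Action(BaseModel):
    name: str


class LogAction(Action):
    log_level: str
    timestamp: str


T = TypeVar("T", bound=Action)


class Alert(BaseModel, Generic[T]):
    id: int
    message: str
    action: T


log = LogAction(name="Error Log", log_level="ERROR", timestamp="2024-04-20T10:45:00")

# Unparametrized: validated against the bound, serialized like Any -> all fields kept.
print(Alert(id=1, message="m", action=log).model_dump()["action"])
# {'name': 'Error Log', 'log_level': 'ERROR', 'timestamp': '2024-04-20T10:45:00'}

# Parametrized with the base class: serialized with Action's schema -> extra fields dropped.
print(Alert[Action](id=1, message="m", action=log).model_dump()["action"])
# {'name': 'Error Log'}

# The JSON schema of the unparametrized model only references Action.
print(Alert.model_json_schema()["properties"]["action"])


# Alternative without generics (hypothetical model name): mark the field
# so that it is serialized by the runtime type of its value.
class AlertAny(BaseModel):
    id: int
    message: str
    action: SerializeAsAny[Action]


print(AlertAny(id=1, message="m", action=log).model_dump()["action"])
# {'name': 'Error Log', 'log_level': 'ERROR', 'timestamp': '2024-04-20T10:45:00'}
```

Both variants keep the field typed as the base class for validation; the difference is that SerializeAsAny states the serialization intent directly on the field instead of relying on the generic being left unparametrized.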