Tags: python, pydantic, pydantic-v2

Remove field from all nested pydantic models


I would like to define a field on a model that can be removed from every nested occurrence when calling model_dump, by whatever means. Here is my attempt:

# Trying nested properties
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class BaseModel2(BaseModel):
    class_name: Optional[str] = Field(None, validate_default=True)

    @field_validator("class_name")
    @classmethod
    def set_class_name(cls, v):
        if v is None:
            return cls.__name__
        else:
            raise ValueError("class_name must not be set")


class Level3(BaseModel2):
    whatever: int = 10


class Level2(BaseModel2):
    whatever: int
    level3: Level3


class Level1(BaseModel2):
    whenever: Optional[float] = 1.1
    level2: Level2


m = Level1(whenever=3.14, level2=Level2(whatever=123, level3=Level3(whatever=20)))
print(m.model_dump(exclude={"class_name": True, "__all__": {"class_name"}}))
# Output:
# {'whenever': 3.14, 'level2': {'whatever': 123, 'level3': {'class_name': 'Level3', 'whatever': 20}}}

I expected exclude to let me remove every occurrence of class_name, but so far I have not managed it: the innermost class_name is still present in the output.

Ultimate aim

If the above is not possible, then maybe something else is. My ultimate aim is to allow a context-specific model dump.

Crucially, I do not want to change the serialisation of a single field. I want to add some information (the class name of the model, cls.__name__) to the serialisation of the model itself. All models inheriting from it would then dump that information as well. This discussion shows how I achieved this in V1, where dict was actually recursive: https://github.com/pydantic/pydantic/discussions/11078

Edit

For instance, I could imagine using a custom model serializer, but then I run into recursion problems, or (commented out) overriding model_dump:

from typing import Any, Dict

from pydantic import BaseModel, SerializationInfo, model_serializer


class BaseModel2(BaseModel):
    # def model_dump(self, **kwargs):
    #     _dict = super().model_dump(serialize_as_any=True, **kwargs)
    #     # _dict["__vizro_model__"] = self.__class__.__name__
    #     return _dict

    @model_serializer
    def ser_model(self, info: SerializationInfo) -> Dict[str, Any]:
        print("CONTEXT", info)
        _dict = {}  # self.model_dump()
        _dict["__vizro_model__"] = self.__class__.__name__
        return _dict

Solution

  • You need to use a model_serializer in wrap mode. This gives you access to a handler function that produces the default serialisation, which you can then modify.

    For example:

    from typing import Any
    from pydantic import (BaseModel, SerializerFunctionWrapHandler,
                          SerializationInfo, model_serializer)
    
    class NamedBaseModel(BaseModel):
        # The annotations on the following function are optional and
        # only serve as documentation.
        @model_serializer(mode='wrap')
        def serialize(
                self,
                handler: SerializerFunctionWrapHandler,
                info: SerializationInfo,
        ) -> dict[str, Any]:
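            # The handler's optional second argument is an index/key, not
            # the SerializationInfo, so only the instance is passed here.
            result = handler(self)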
            result['__name__'] = self.__class__.__name__
            return result
    
    class FooModel(NamedBaseModel):
        val: int
    
    class BarModel(NamedBaseModel):
        val: str
        foo: FooModel
    
    obj = BarModel(val='str', foo=FooModel(val=123))
    assert obj.model_dump() == {
        '__name__': 'BarModel',
        'val': 'str',
        'foo': {
            '__name__': 'FooModel', 
            'val': 123,
        },
    }
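
To get closer to the context-specific dump mentioned in the question, the same wrap serializer can consult the serialization context. The following is a minimal sketch, assuming a pydantic version in which model_dump accepts a context argument (2.7 or later, as far as I know). The context key "add_class_name" is a hypothetical name chosen for this illustration and is not part of pydantic.

```python
from typing import Any

from pydantic import (
    BaseModel,
    SerializationInfo,
    SerializerFunctionWrapHandler,
    model_serializer,
)


class NamedBaseModel(BaseModel):
    @model_serializer(mode="wrap")
    def serialize(
        self,
        handler: SerializerFunctionWrapHandler,
        info: SerializationInfo,
    ) -> dict[str, Any]:
        # Default serialisation of this model's fields.
        result = handler(self)
        # info.context is None when no context was passed to model_dump.
        context = info.context or {}
        # "add_class_name" is a hypothetical key used only in this sketch.
        if context.get("add_class_name", False):
            result["__name__"] = self.__class__.__name__
        return result


class FooModel(NamedBaseModel):
    val: int


class BarModel(NamedBaseModel):
    val: str
    foo: FooModel


obj = BarModel(val="str", foo=FooModel(val=123))

# Without a context, the class name is omitted at every nesting level.
plain = obj.model_dump()

# With the flag set, every nested model adds its own class name.
named = obj.model_dump(context={"add_class_name": True})
```

Because the context is handed down to nested serializers, a single flag at the call site switches the extra key on or off for all levels, which avoids having to express the removal through exclude.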