Search code examples
python pydantic discriminated-union

How to make Pydantic discriminate a nested object based on a field?


I have this Pydantic model:

import typing
import pydantic

# Payload shape holding a single string field.
class TypeAData(pydantic.BaseModel):
    aStr: str

# Payload shape holding a single integer field.
class TypeBData(pydantic.BaseModel):
    bNumber: int

# Payload shape holding a single boolean field.
class TypeCData(pydantic.BaseModel):
    cBoolean: bool

# Root model: a `type` label, a name, and one of the three payload shapes.
# NOTE: nothing declared here links the value of `type` to the concrete class
# selected for `data`; `data` is a plain (undiscriminated) union.
class MyData(pydantic.BaseModel):
    type: typing.Literal['A', 'B', 'C']
    name: str
    data: TypeAData | TypeBData | TypeCData

However, if type is equal to "A" and data contains TypeBData, it'll validate correctly when it shouldn't. This could be an alternative:

# Root model with a hand-written dispatch from a type label to a payload class.
class MyData(pydantic.BaseModel):
    type: typing.Literal['A', 'B', 'C']
    name: str
    data: TypeAData | TypeBData | TypeCData

    # pydantic v1-style field validator; `pre=True` runs it on the raw input
    # value of `data`, before the union itself is validated.
    @pydantic.validator('data', pre=True, always=True)
    def validate_data(cls, data):
        # Only dict input is dispatched; anything else falls through to the
        # ValueError at the bottom.
        if isinstance(data, dict):
            # NOTE(review): this looks up 'type' inside the nested `data`
            # payload itself, not in the sibling `type` field of MyData --
            # confirm that the input really carries the label at this level.
            data_type = data.get('type')
            # The label -> class mapping below repeats the Literal values and
            # the union members already declared on the fields above.
            if data_type == 'A':
                return TypeAData(**data)
            elif data_type == 'B':
                return TypeBData(**data)
            elif data_type == 'C':
                return TypeCData(**data)

        # Reached for non-dict input and for unknown/missing labels.
        raise ValueError('Invalid data or type')

It works; however, is there a better way to do it without repeating the enum keys ('A', 'B' and 'C') and values (TypeAData, TypeBData and TypeCData) twice?

I've tried using discriminated unions, but since the type field and the discriminated fields are in different levels within the model (the latter is inside a nested object), I could not make further progress on this.


Solution

  • Kinda late to the party, but here it goes.

    Consider the input I1:

    {
        "type": "A",
        "name": "Model A",
        "data": {"text": "Some string"},
    }
    

    Using

    import typing
    
    import pydantic
    
    # First attempt: declare `data` as a discriminated union keyed on 'type'.
    # The union members below do not declare a `type` field yet, so pydantic
    # cannot build the discriminated union from them.

    class TypeAData(pydantic.BaseModel):
        # Named `text` so that it matches the key used in the input below.
        text: str

    class TypeBData(pydantic.BaseModel):
        integer: int

    class TypeCData(pydantic.BaseModel):
        boolean: bool

    class MyData(pydantic.BaseModel):
        type: typing.Literal['A', 'B', 'C']
        name: str
        # `discriminator='type'` asks pydantic to pick the union member by the
        # value of a `type` key found inside the nested `data` object.
        data: TypeAData | TypeBData | TypeCData = pydantic.Field(discriminator='type')

    model = MyData.model_validate({
        'type': 'A',
        'name': 'Model A',
        'data': {'text': 'Some string'},
    })
    print(model)
    

    will, as noted, fail with

    pydantic.errors.PydanticUserError: 
        Model 'TypeAData' needs a discriminator field for key 'type'
    

    The first fix is to add a type field to each member of the union, as follows:

    # Each union member carries its own single-valued `type` tag so that the
    # discriminated union can tell the members apart. `exclude=True` keeps the
    # tag out of serialised output and `repr=False` keeps it out of the repr.

    class TypeAData(pydantic.BaseModel):
        type: typing.Literal['A'] = pydantic.Field(exclude=True, repr=False)
        text: str

    class TypeBData(pydantic.BaseModel):
        type: typing.Literal['B'] = pydantic.Field(exclude=True, repr=False)
        integer: int

    class TypeCData(pydantic.BaseModel):
        type: typing.Literal['C'] = pydantic.Field(exclude=True, repr=False)
        # A boolean payload, consistent with the earlier TypeCData definition.
        boolean: bool
    

    We've added the exclude and repr parameters because this field is only needed inside the Python code while the data is being validated; we don't want it to show up in the serialized output or in the model's repr.

    However, now the error will be

    pydantic_core._pydantic_core.ValidationError: 1 validation error for MyData
    data
      Unable to extract tag using discriminator 'type' [type=union_tag_not_found, input_value={'text': 'Some string'}, input_type=dict]
        For further information visit https://errors.pydantic.dev/2.10/v/union_tag_not_found
    

    The second fix is to transform I1 into the following input I2:

    {
        "type": "A",
        "name": "Model A",
        "data": {"text": "Some string", "type": "A"},
    }
    

    This will be done by adding a model validator in MyData as follows:

    class MyData(pydantic.BaseModel):
        # ...

        @pydantic.model_validator(mode='wrap')
        @classmethod
        def discriminate_nested(
            cls,
            data: typing.Any,
            handler: pydantic.ValidatorFunctionWrapHandler,
        ) -> typing.Self:
            """
            Copy the root-level `type` label into the nested `data` object.

            The discriminated union on `data` looks for the tag inside the
            nested object, while the input only carries it at the root level.

            :param data: the raw input handed to `MyData` validation.
            :param handler: pydantic's inner validation for `MyData`.
            :return: whatever `handler` produces for the (possibly enriched) input.
            """
            nested = data.get('data') if isinstance(data, dict) else None
            if isinstance(nested, dict) and 'type' in data:
                # Build new dicts instead of writing into the caller's input:
                # `{**data}` alone is a shallow copy and would still share the
                # nested dict with the caller.
                data = {**data, 'data': {**nested, 'type': data['type']}}
            # Always delegate to the handler. Returning the input directly
            # would skip validation altogether, and incomplete input is
            # better reported by pydantic than by a bare KeyError.
            return handler(data)
    

    Running the code, now updated,

    import typing
    
    import pydantic
    
    # Union members: each one carries a single-valued `type` tag used by the
    # discriminated union; the tag is excluded from serialisation and repr.

    class TypeAData(pydantic.BaseModel):
        type: typing.Literal['A'] = pydantic.Field(exclude=True, repr=False)
        text: str

    class TypeBData(pydantic.BaseModel):
        type: typing.Literal['B'] = pydantic.Field(exclude=True, repr=False)
        integer: int

    class TypeCData(pydantic.BaseModel):
        type: typing.Literal['C'] = pydantic.Field(exclude=True, repr=False)
        boolean: bool

    class MyData(pydantic.BaseModel):
        type: typing.Literal['A', 'B', 'C']
        name: str
        # The member of the union is selected by the `type` key found inside
        # the nested object (injected by the validator below).
        data: TypeAData | TypeBData | TypeCData = pydantic.Field(discriminator='type')

        @pydantic.model_validator(mode='wrap')
        @classmethod
        def discriminate_nested(
                cls,
                data: typing.Any,
                handler: pydantic.ValidatorFunctionWrapHandler,
        ) -> typing.Self:
            """
            Copy the root-level `type` label into the nested `data` object.

            :param data: the raw input handed to `MyData` validation.
            :param handler: pydantic's inner validation for `MyData`.
            :return: whatever `handler` produces for the (possibly enriched) input.
            """
            nested = data.get('data') if isinstance(data, dict) else None
            if isinstance(nested, dict) and 'type' in data:
                # New dicts are built so the caller's input is left untouched
                # (a shallow `{**data}` copy still shares the nested dict).
                data = {**data, 'data': {**nested, 'type': data['type']}}
            # Always run the inner validation; returning the raw input here
            # would let unvalidated values through.
            return handler(data)

    model = MyData.model_validate({
        'type': 'A',
        'name': 'Model A',
        'data': {'text': 'Some string'},
    })
    print(model)
    

    will give us the expected output:

    type='A' name='Model A' data=TypeAData(text='Some string')
    

    But we can do better

    Note the duplications, highlighted in red, green, yellow, blue and orange:

    Code 1

    To fix that, I've created a PydanticUnionHandler, which you can use as follows:

    import pydantic
    
    # One handler per nested union; `nested_field` is the name of the tag field
    # that the handler adds to every registered member.
    TypeData = PydanticUnionHandler(nested_field='type')

    # `add(label)` registers the class as the union member selected by `label`.
    @TypeData.add('A')
    class TypeAData(pydantic.BaseModel):
        text: str

    @TypeData.add('B')
    class TypeBData(pydantic.BaseModel):
        integer: int

    @TypeData.add('C')
    class TypeCData(pydantic.BaseModel):
        boolean: bool

    # `register(...)` adds the union field and the label field to the root model.
    @TypeData.register(
        data_spec=PydanticUnionHandler.RegisterSpec(
            field_name='data',
            field=pydantic.Field(),  # Pass e.g. `pydantic.Field(alias=...)` to define an alias
        ),
        type_spec=PydanticUnionHandler.RegisterSpec(
            field_name='type',
            field=pydantic.Field(),  # Pass e.g. `pydantic.Field(alias=...)` to define an alias
        ),
    )
    class MyData(pydantic.BaseModel):
        name: str

    model = MyData.model_validate({
        'type': 'A',
        'name': 'Model A',
        'data': {'text': 'Some string'},
    })
    print(model)
    

    This also outputs

    type='A' name='Model A' data=TypeAData(text='Some string')
    

    Note now how the duplications were reduced:

    Code 2

    Its implementation is as follows (Python 3.12 tested):

    import dataclasses
    import typing
    
    import pydantic
    import pydantic.fields
    import pydantic._internal._decorators
    
    class PydanticUnionHandler:
        @dataclasses.dataclass(frozen=True, kw_only=True)
        class RegisterSpec:
            field_name: str
            field: pydantic.Field
    
        _types: dict[type, typing.LiteralString]
    
        def __init__(self, *, nested_field: str):
            self.nested_field = nested_field
            self._types = {}
    
        def add[R: pydantic.BaseModel](
                self,
                name: typing.LiteralString,
        ) -> typing.Callable[[typing.Type[R]], typing.Type[R]]:
            """
            Register a nested union model to this handler.
    
            :param name: which name should the `type` field of the root's model have.
            :return: the same decorated model, now with a new field (named by `self.nested_field`).
            """
    
            def _(cls: typing.Type[R]) -> typing.Type[R]:
                # Dynamic add the `pydantic.Field(exclude=True, repr=False)` to the decorated `cls`
                cls.model_fields.update({
                    self.nested_field: pydantic.fields.FieldInfo.merge_field_infos(
                        pydantic.fields.FieldInfo.from_annotation(typing.Literal[name]),
                        pydantic.Field(exclude=True, repr=False),
                    ),
                })
                cls.model_rebuild(force=True)
    
                # Register the decorated `cls` into this handler
                self._types[cls] = name
    
                return cls
    
            return _
    
        def register[R: pydantic.BaseModel](
                self,
                *,
                data_spec: RegisterSpec,
                type_spec: RegisterSpec,
        ) -> typing.Callable[[typing.Type[R]], typing.Type[R]]:
            """
            Modify the root model, which will contain the nested unions.
    
            :param data_spec: how the field containing the nested unions should be declared.
            :param type_spec: how the field containing the discriminator label should be declared.
            :return: the same decorated model, now with two news fields (named by `data_spec.field_name` and
                `type_spec.field_name), and a model validator.
            """
    
            def _(cls: typing.Type[R]) -> typing.Type[R]:
                # The model validator to be added dynamically to the model
                def discriminated_nested(
                        _,
                        data: typing.Any,
                        handler: pydantic.ValidatorFunctionWrapHandler,
                ) -> typing.Self:
                    if isinstance(data, dict):
                        updated_data = {**data}
                        data_key: str = data_spec.field.alias or data_spec.field_name
                        type_key: str = type_spec.field.alias or type_spec.field_name
                        updated_data[data_key][self.nested_field] = updated_data[type_key]
                        return handler(updated_data)
                    return data
    
                # Source: https://github.com/pydantic/pydantic/issues/1937#issuecomment-1853320238
                cls.model_fields.update({
                    # Adds the nested union field
                    data_spec.field_name: pydantic.fields.FieldInfo.merge_field_infos(
                        pydantic.fields.FieldInfo.from_annotation(
                            typing.Annotated[
                                typing.Union[*[
                                    typing.Annotated[type_, pydantic.Tag(name)]
                                    for type_, name in self._types.items()
                                ]],
                                pydantic.Discriminator(self.nested_field),
                            ],
                        ),
                        data_spec.field,
                    ),
                    # Adds the discriminator label field
                    type_spec.field_name: pydantic.fields.FieldInfo.merge_field_infos(
                        pydantic.fields.FieldInfo.from_annotation(
                            typing.Union[*[typing.Literal[name] for name in self._types.values()]],
                        ),
                        type_spec.field,
                    ),
                })
    
                # Adds the model field
                cls.discriminated_nested = classmethod(discriminated_nested)
                cls.__pydantic_decorators__.model_validators.update({
                    discriminated_nested.__name__: pydantic._internal._decorators.Decorator.build(
                        cls,
                        cls_var_name=discriminated_nested.__name__,
                        shim=None,
                        info=pydantic._internal._decorators.ModelValidatorDecoratorInfo(mode='wrap'),
                    ),
                })
                cls.model_rebuild(force=True)
    
                return cls
    
            return _