I have this Pydantic model:
import typing
import pydantic
# Payload intended for `type == 'A'`.
class TypeAData(pydantic.BaseModel):
    aStr: str


# Payload intended for `type == 'B'`.
class TypeBData(pydantic.BaseModel):
    bNumber: int


# Payload intended for `type == 'C'`.
class TypeCData(pydantic.BaseModel):
    cBoolean: bool


# Root model: a `type` tag, a `name` and a payload.
class MyData(pydantic.BaseModel):
    type: typing.Literal['A', 'B', 'C']
    name: str
    # NOTE: nothing links `type` to which member of this union is chosen, so
    # e.g. type='A' with a TypeBData-shaped payload still validates.
    data: TypeAData | TypeBData | TypeCData
However, if `type` is equal to `"A"` and `data` contains a `TypeBData`, the model still passes validation even though it shouldn't. This could be an alternative:
class MyData(pydantic.BaseModel):
    # Root model: the outer `type` tag selects which payload model `data` must match.
    type: typing.Literal['A', 'B', 'C']
    name: str
    data: TypeAData | TypeBData | TypeCData

    @pydantic.field_validator('data', mode='before')
    @classmethod
    def validate_data(cls, data: typing.Any, info: pydantic.ValidationInfo) -> typing.Any:
        """Validate ``data`` against the payload model chosen by the outer ``type``.

        ``type`` is declared before ``data``, so its already-validated value is
        available in ``info.data``. It is absent there when ``type`` itself
        failed validation; in that case this validator raises as well.

        ``model_validate`` accepts both raw dicts and already-built instances,
        and rejects an instance of a different payload model.

        :raises ValueError: if ``type`` is missing or invalid, or if ``data``
            does not match the selected payload model (pydantic's
            ``ValidationError`` is a ``ValueError`` subclass).
        """
        # Read the tag from the OUTER model, not from the nested payload.
        data_type = info.data.get('type')
        if data_type == 'A':
            return TypeAData.model_validate(data)
        elif data_type == 'B':
            return TypeBData.model_validate(data)
        elif data_type == 'C':
            return TypeCData.model_validate(data)
        raise ValueError('Invalid data or type')
It works; however, is there a better way to do it without writing the enum keys (`'A'`, `'B'` and `'C'`) and values (`TypeAData`, `TypeBData` and `TypeCData`) twice?
I've tried using discriminated unions, but the `type` field and the discriminated fields live at different levels of the model (the latter are inside a nested object), so I could not make further progress.
Kinda late to the party, but here goes.
Consider the input I1:
{
"type": "A",
"name": "Model A",
"data": {"text": "Some string"},
}
Using
import typing
import pydantic
class TypeAData(pydantic.BaseModel):
    # Field renamed to `text` so it matches the input below and the later snippets.
    text: str


class TypeBData(pydantic.BaseModel):
    integer: int


class TypeCData(pydantic.BaseModel):
    boolean: bool


class MyData(pydantic.BaseModel):
    type: typing.Literal['A', 'B', 'C']
    name: str
    # Expected to fail: none of the union members declares a `type` field to
    # discriminate on, so pydantic raises PydanticUserError while building
    # MyData's schema (by default at class definition), before the
    # `model_validate` call below runs.
    data: TypeAData | TypeBData | TypeCData = pydantic.Field(discriminator='type')


model = MyData.model_validate({
    'type': 'A',
    'name': 'Model A',
    'data': {'text': 'Some string'},
})
print(model)
will, as noted, fail with
pydantic.errors.PydanticUserError:
Model 'TypeAData' needs a discriminator field for key 'type'
The first fix is to add a `type` field to each member of the union, as follows:
# Each member carries its own tag in a `type` field, which is what
# `pydantic.Field(discriminator='type')` on the parent looks for.
# `exclude=True` keeps the tag out of serialized output and `repr=False`
# keeps it out of the model's repr.
class TypeAData(pydantic.BaseModel):
    type: typing.Literal['A'] = pydantic.Field(exclude=True, repr=False)
    text: str


class TypeBData(pydantic.BaseModel):
    type: typing.Literal['B'] = pydantic.Field(exclude=True, repr=False)
    integer: int


class TypeCData(pydantic.BaseModel):
    type: typing.Literal['C'] = pydantic.Field(exclude=True, repr=False)
    # Was `int`; a boolean payload must be declared as `bool`.
    boolean: bool
We've added the `exclude` and `repr` parameters because this field is only needed while validating the input in Python. `exclude=True` keeps it out of exported data (e.g. `model_dump()`), and `repr=False` keeps it out of the model's `repr`.
However, now the error will be
pydantic_core._pydantic_core.ValidationError: 1 validation error for MyData
data
Unable to extract tag using discriminator 'type' [type=union_tag_not_found, input_value={'text': 'Some string'}, input_type=dict]
For further information visit https://errors.pydantic.dev/2.10/v/union_tag_not_found
The second fix is to transform I1 into the following input I2:
{
"type": "A",
"name": "Model A",
"data": {"text": "Some string", "type": "A"},
}
This is done by adding a wrap-mode model validator to `MyData`, as follows:
class MyData(pydantic.BaseModel):
    # ...

    @pydantic.model_validator(mode='wrap')
    @classmethod
    def discriminate_nested(
        cls,
        data: typing.Any,
        handler: pydantic.ValidatorFunctionWrapHandler,
    ) -> typing.Self:
        """Copy the outer ``type`` tag into the nested ``data`` dict, then validate.

        Both dict levels are copied, so the caller's input is never mutated.
        Input that is not a dict, or that lacks ``type`` or a dict-valued
        ``data``, is handed to ``handler`` unchanged. Pydantic then still
        validates it and reports problems as ordinary validation errors,
        instead of this method raising KeyError/TypeError or returning
        unvalidated input.
        """
        if isinstance(data, dict):
            nested = data.get('data')
            if 'type' in data and isinstance(nested, dict):
                data = {**data, 'data': {**nested, 'type': data['type']}}
        return handler(data)
Running the updated code
import typing
import pydantic
# Each member carries a hidden `type` tag used only as the union discriminator:
# `exclude=True` keeps it out of serialized output, `repr=False` out of repr.
class TypeAData(pydantic.BaseModel):
    type: typing.Literal['A'] = pydantic.Field(exclude=True, repr=False)
    text: str


class TypeBData(pydantic.BaseModel):
    type: typing.Literal['B'] = pydantic.Field(exclude=True, repr=False)
    integer: int


class TypeCData(pydantic.BaseModel):
    type: typing.Literal['C'] = pydantic.Field(exclude=True, repr=False)
    boolean: bool


class MyData(pydantic.BaseModel):
    type: typing.Literal['A', 'B', 'C']
    name: str
    # Discriminated on the nested `type` tag, which `discriminate_nested`
    # fills in from the outer `type` field before validation.
    data: TypeAData | TypeBData | TypeCData = pydantic.Field(discriminator='type')

    @pydantic.model_validator(mode='wrap')
    @classmethod
    def discriminate_nested(
        cls,
        data: typing.Any,
        handler: pydantic.ValidatorFunctionWrapHandler,
    ) -> typing.Self:
        """Copy the outer ``type`` tag into the nested ``data`` dict, then validate.

        Both dict levels are copied, so the caller's input is never mutated.
        Anything else (non-dict input, missing ``type``, non-dict ``data``) is
        passed to ``handler`` unchanged, so pydantic still validates it and
        reports proper validation errors.
        """
        if isinstance(data, dict):
            nested = data.get('data')
            if 'type' in data and isinstance(nested, dict):
                data = {**data, 'data': {**nested, 'type': data['type']}}
        return handler(data)


model = MyData.model_validate({
    'type': 'A',
    'name': 'Model A',
    'data': {'text': 'Some string'},
})
print(model)
will give us the expected output:
type='A' name='Model A' data=TypeAData(text='Some string')
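Note the duplications (highlighted in red, green, yellow, blue and orange in the original image, which is not reproduced here):
- the tags `'A'`, `'B'` and `'C'` are written both in each member's `type` field and in `MyData.type`;
- the nested field name `'type'` is repeated in every member, in `discriminator='type'` and in the validator;
- `pydantic.Field(exclude=True, repr=False)` is repeated for every member;
- the outer field names `'data'` and `'type'` are repeated in the validator;
- every member class is listed again in the `data` union.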
To fix that, I've created a `PydanticUnionHandler`, which you can use as follows:
import pydantic
# One handler per nested union; `nested_field` names the hidden tag field that
# `add` injects into every member model.
TypeData = PydanticUnionHandler(nested_field='type')


# Each tag is written exactly once, here.
@TypeData.add('A')
class TypeAData(pydantic.BaseModel):
    text: str


@TypeData.add('B')
class TypeBData(pydantic.BaseModel):
    integer: int


@TypeData.add('C')
class TypeCData(pydantic.BaseModel):
    # Was `int`; a boolean payload must be declared as `bool`.
    boolean: bool


# Must run after every `add`: it generates the `data` union and the `type`
# tag field from the members registered so far.
@TypeData.register(
    data_spec=PydanticUnionHandler.RegisterSpec(
        field_name='data',
        field=pydantic.Field(),  # e.g. pydantic.Field(alias=...) to customize the field
    ),
    type_spec=PydanticUnionHandler.RegisterSpec(
        field_name='type',
        field=pydantic.Field(),  # e.g. pydantic.Field(alias=...) to customize the field
    ),
)
class MyData(pydantic.BaseModel):
    name: str


model = MyData.model_validate({
    'type': 'A',
    'name': 'Model A',
    'data': {'text': 'Some string'},
})
print(model)
This also outputs
type='A' name='Model A' data=TypeAData(text='Some string')
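Note how the duplications are now reduced: each tag is written once, in `@TypeData.add(...)`. The union, the tag fields and the validator are generated by `@TypeData.register(...)`.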
Its implementation is as follows (tested with Python 3.12):
import dataclasses
import typing
import pydantic
import pydantic.fields
import pydantic._internal._decorators
class PydanticUnionHandler:
@dataclasses.dataclass(frozen=True, kw_only=True)
class RegisterSpec:
field_name: str
field: pydantic.Field
_types: dict[type, typing.LiteralString]
def __init__(self, *, nested_field: str):
self.nested_field = nested_field
self._types = {}
def add[R: pydantic.BaseModel](
self,
name: typing.LiteralString,
) -> typing.Callable[[typing.Type[R]], typing.Type[R]]:
"""
Register a nested union model to this handler.
:param name: which name should the `type` field of the root's model have.
:return: the same decorated model, now with a new field (named by `self.nested_field`).
"""
def _(cls: typing.Type[R]) -> typing.Type[R]:
# Dynamic add the `pydantic.Field(exclude=True, repr=False)` to the decorated `cls`
cls.model_fields.update({
self.nested_field: pydantic.fields.FieldInfo.merge_field_infos(
pydantic.fields.FieldInfo.from_annotation(typing.Literal[name]),
pydantic.Field(exclude=True, repr=False),
),
})
cls.model_rebuild(force=True)
# Register the decorated `cls` into this handler
self._types[cls] = name
return cls
return _
def register[R: pydantic.BaseModel](
self,
*,
data_spec: RegisterSpec,
type_spec: RegisterSpec,
) -> typing.Callable[[typing.Type[R]], typing.Type[R]]:
"""
Modify the root model, which will contain the nested unions.
:param data_spec: how the field containing the nested unions should be declared.
:param type_spec: how the field containing the discriminator label should be declared.
:return: the same decorated model, now with two news fields (named by `data_spec.field_name` and
`type_spec.field_name), and a model validator.
"""
def _(cls: typing.Type[R]) -> typing.Type[R]:
# The model validator to be added dynamically to the model
def discriminated_nested(
_,
data: typing.Any,
handler: pydantic.ValidatorFunctionWrapHandler,
) -> typing.Self:
if isinstance(data, dict):
updated_data = {**data}
data_key: str = data_spec.field.alias or data_spec.field_name
type_key: str = type_spec.field.alias or type_spec.field_name
updated_data[data_key][self.nested_field] = updated_data[type_key]
return handler(updated_data)
return data
# Source: https://github.com/pydantic/pydantic/issues/1937#issuecomment-1853320238
cls.model_fields.update({
# Adds the nested union field
data_spec.field_name: pydantic.fields.FieldInfo.merge_field_infos(
pydantic.fields.FieldInfo.from_annotation(
typing.Annotated[
typing.Union[*[
typing.Annotated[type_, pydantic.Tag(name)]
for type_, name in self._types.items()
]],
pydantic.Discriminator(self.nested_field),
],
),
data_spec.field,
),
# Adds the discriminator label field
type_spec.field_name: pydantic.fields.FieldInfo.merge_field_infos(
pydantic.fields.FieldInfo.from_annotation(
typing.Union[*[typing.Literal[name] for name in self._types.values()]],
),
type_spec.field,
),
})
# Adds the model field
cls.discriminated_nested = classmethod(discriminated_nested)
cls.__pydantic_decorators__.model_validators.update({
discriminated_nested.__name__: pydantic._internal._decorators.Decorator.build(
cls,
cls_var_name=discriminated_nested.__name__,
shim=None,
info=pydantic._internal._decorators.ModelValidatorDecoratorInfo(mode='wrap'),
),
})
cls.model_rebuild(force=True)
return cls
return _