pydantic, pydantic-v2

Elegant way to apply constraints based on nested discriminators in pydantic


Pydantic allows list elements to be nested discriminated unions. Is there an elegant way to apply constraints (such as MaxLen or MinLen) to the sub-list of elements selected by a discriminator value, without writing a custom validator?

For example, in the PetsModel below, I want to require at least one Cat (MinLen(1)) and at most two Dogs (MaxLen(2)) when validating input like this:

PetsModel.model_validate(
    {
        "pets": [
            {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"},
            {"pet_type": "dog", "name": "dog_name"},
        ]
    }
)

Adapted from https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions

from typing import Literal, Union

from typing_extensions import Annotated

from pydantic import BaseModel, Field, ValidationError


class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str


class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str


Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]


class PetsModel(BaseModel):
    pets: list[Pet]  # list of pets
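To see what a validator on pets would actually receive, here is a quick check (a sketch, assuming Pydantic v2 and Python 3.9+; the models are repeated so the snippet runs on its own). The nested discriminators turn each dict into an instance of a concrete class such as BlackCat or Dog, so any count-based rule has to work on those instances rather than on the raw dicts.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str


class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str


# Inner discriminator: picks BlackCat or WhiteCat by "color"
Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


# Outer discriminator: picks a Cat variant or Dog by "pet_type"
Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]


class PetsModel(BaseModel):
    pets: list[Pet]


m = PetsModel.model_validate(
    {
        "pets": [
            {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"},
            {"pet_type": "dog", "name": "dog_name"},
        ]
    }
)

# Each element is an instance of the concrete model chosen by the discriminators
print([type(p).__name__ for p in m.pets])  # ['BlackCat', 'Dog']
```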

I know that I can apply constraints to the pets list as a whole, as follows:

from annotated_types import MaxLen, MinLen

class PetsModel(BaseModel):
    pets: Annotated[list[Pet], MinLen(1), MaxLen(5)]  # list of pets

However, I want to constrain how many Cat and how many Dog elements the pets list contains.


Solution

  • You could create a field validator on pets that runs after Pydantic's own validation, so you'd be checking a list of already-parsed model instances (BlackCat, WhiteCat, Dog) rather than raw dicts, like so:

    from collections import Counter
    from typing import List

    from annotated_types import MaxLen, MinLen
    from pydantic import BaseModel, field_validator
    from typing_extensions import Annotated

    # Pet is the nested discriminated union defined in the question


    class PetsModel(BaseModel):
        pets: Annotated[List[Pet], MinLen(1), MaxLen(5)]

        @field_validator('pets')
        @classmethod
        def special_rules(cls, v: List[Pet]) -> List[Pet]:
            rules = {
                'Dog': {'min': 1, 'max': None},
                'BlackCat': {'min': None, 'max': 5},
            }
            # This counts the concrete class name (BlackCat, not Cat)
            c = Counter(x.__class__.__name__ for x in v)

            # Handy helper: fall back to a default when a bound is None
            replace_none = lambda x, y: x if x is not None else y

            for key in rules:
                x_min = replace_none(rules[key].get('min'), 0)
                x_max = replace_none(rules[key].get('max'), float('inf'))
                count = c.get(key, 0)
                if not x_min <= count <= x_max:
                    raise ValueError(
                        f"expected {key} count in [{x_min}, {x_max}], got {count}"
                    )
            # A field validator must return the (possibly modified) value
            return v
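    If you want the limits from the question (at least one cat, at most two dogs) instead of per-subclass limits, one variant is to count the outer discriminator, pet_type, rather than class names, so BlackCat and WhiteCat both count as cats. A minimal sketch, assuming Pydantic v2 and Python 3.9+; PET_TYPE_LIMITS and first_error are hypothetical names introduced here for illustration:

```python
from collections import Counter
from typing import Annotated, Literal, Optional, Union

from annotated_types import MaxLen, MinLen
from pydantic import BaseModel, Field, ValidationError, field_validator


class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str


class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str


Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]

# Hypothetical rules table: pet_type -> (min count, max count); None means unbounded
PET_TYPE_LIMITS = {'cat': (1, None), 'dog': (None, 2)}


class PetsModel(BaseModel):
    pets: Annotated[list[Pet], MinLen(1), MaxLen(5)]

    @field_validator('pets')
    @classmethod
    def check_pet_type_counts(cls, v: list) -> list:
        # Count by the outer discriminator so BlackCat and WhiteCat both count as 'cat'
        counts = Counter(p.pet_type for p in v)
        for pet_type, (lo, hi) in PET_TYPE_LIMITS.items():
            n = counts[pet_type]
            if lo is not None and n < lo:
                raise ValueError(f"need at least {lo} {pet_type} item(s), got {n}")
            if hi is not None and n > hi:
                raise ValueError(f"at most {hi} {pet_type} item(s) allowed, got {n}")
        return v  # field validators must return the value


def first_error(data: dict) -> Optional[str]:
    """Hypothetical helper: first validation error message, or None if valid."""
    try:
        PetsModel.model_validate(data)
    except ValidationError as e:
        return e.errors()[0]["msg"]
    return None


cat = {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"}
dog = {"pet_type": "dog", "name": "dog_name"}

print(first_error({"pets": [cat, dog]}))            # None: one cat, one dog
print(first_error({"pets": [dog]}))                 # rejected: no cat
print(first_error({"pets": [cat, dog, dog, dog]}))  # rejected: three dogs
```

    The trade-off is that counting by pet_type ignores the inner color discriminator; counting class names, as in the answer above, is what you want when the limits differ per cat color.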
    
    

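    Now, you can be even fancier and store the rules object on the field itself via json_schema_extra, like so (note that anything placed in json_schema_extra also appears in the model's generated JSON schema):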

    
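    from typing import List

    from annotated_types import MaxLen, MinLen
    from pydantic import BaseModel, Field, ValidationInfo, field_validator
    from typing_extensions import Annotated

    rules = {
        'Dog': {'min': 1, 'max': None},
        'BlackCat': {'min': None, 'max': 5},
    }


    class PetsModel(BaseModel):
        pets: Annotated[
            List[Pet], MinLen(1), MaxLen(5),
            Field(json_schema_extra={"rules": rules})
        ]

        @field_validator('pets')
        @classmethod
        def special_rules(cls, v: List[Pet], info: ValidationInfo) -> List[Pet]:
            # Read the rules back from this field's metadata
            rules = cls.model_fields[info.field_name].json_schema_extra["rules"]
            ...  # same counting/checking logic as above
            return v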
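    If the goal is to keep the rules declarative and next to the type, another option is Pydantic v2's AfterValidator placed inside Annotated, built by a small factory. This is still a custom validator under the hood, but no method is needed on the model and the same checks can be attached to any list[Pet] field. A sketch, assuming Pydantic v2 and Python 3.9+; count_limit and first_error are hypothetical helpers, not part of Pydantic:

```python
from collections import Counter
from typing import Annotated, Literal, Optional, Union

from annotated_types import MaxLen, MinLen
from pydantic import AfterValidator, BaseModel, Field, ValidationError


class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str


class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str


Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]


def count_limit(pet_type: str, min_count: int = 0,
                max_count: Optional[int] = None) -> AfterValidator:
    """Hypothetical helper: bound how many list elements have the given pet_type."""

    def check(pets: list) -> list:
        n = Counter(p.pet_type for p in pets)[pet_type]
        if n < min_count or (max_count is not None and n > max_count):
            upper = "unbounded" if max_count is None else max_count
            raise ValueError(f"expected {min_count} to {upper} {pet_type!r} items, got {n}")
        return pets  # after-validators must return the value

    return AfterValidator(check)


class PetsModel(BaseModel):
    # Each count_limit(...) runs after the list and its elements have been validated
    pets: Annotated[
        list[Pet],
        MinLen(1),
        MaxLen(5),
        count_limit('cat', min_count=1),
        count_limit('dog', max_count=2),
    ]


def first_error(data: dict) -> Optional[str]:
    """Hypothetical helper: first validation error message, or None if valid."""
    try:
        PetsModel.model_validate(data)
    except ValidationError as e:
        return e.errors()[0]["msg"]
    return None


cat = {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"}
dog = {"pet_type": "dog", "name": "dog_name"}

print(first_error({"pets": [cat, dog]}))            # None: within all limits
print(first_error({"pets": [dog]}))                 # rejected by the cat limit
print(first_error({"pets": [cat, dog, dog, dog]}))  # rejected by the dog limit
```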
    

    Hope this helps!