Search code examples
python python-3.x pydantic

Create a custom Pydantic class requiring value to be in an interval?


I'm working to build a function parameter validation library using pydantic. We want to be able to validate parameters' types and values. Types are easy enough, but I'm having a hard time creating a class to validate values. Specifically, the first class I want to build is one that requires the value to be in a user-defined interval.

So far, I've written a decorator and a function-based version of ValueInInterval. However, I would prefer to use a class-based approach. Here's an MRE of my issue:

from typing import Any
from typing_extensions import Annotated
from pydantic import Field, validate_call


class ValueInInterval:
    """Annotation helper meant to restrict a value of a given type to an interval.

    This is the failing version from the question: using an instance as a
    parameter annotation under ``validate_call`` raises the ``TypeError``
    quoted in the traceback that follows this snippet.
    """

    def __init__(
        self,
        type_definition: Any,
        start: Any,
        end: Any,
        include_start: bool = True,
        include_end: bool = True,
    ):
        """Store the underlying type and the interval bounds.

        Args:
            type_definition: Type the value is validated as.
            start: Lower bound of the interval.
            end: Upper bound of the interval.
            include_start: If true, the lower bound is emitted as ``ge``;
                otherwise as the strict ``gt``.
            include_end: If true, the upper bound is emitted as ``le``;
                otherwise as the strict ``lt``.
        """
        self.type_definition = type_definition
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    def __call__(self):
        """Return ``Annotated[type_definition, <Field with the bounds>]``.

        Takes no arguments besides ``self``.
        """
        return Annotated[self.type_definition, self.create_field()]

    def create_field(self) -> Field:
        """Build a ``Field`` whose keyword arguments encode the interval."""
        field_config = {}
        # Lower bound: inclusive -> "ge", exclusive -> "gt".
        if self.include_start:
            field_config.update({"ge": self.start})
        else:
            field_config.update({"gt": self.start})
        # Upper bound: inclusive -> "le", exclusive -> "lt".
        if self.include_end:
            field_config.update({"le": self.end})
        else:
            field_config.update({"lt": self.end})

        return Field(**field_config)

    def __get_pydantic_core_schema__(
        self,
        handler,
    ):
        """Produce the schema for ``type_definition`` with the bounds added.

        NOTE(review): the traceback that follows this snippet shows pydantic
        invoking this hook as ``get_schema(source)``, i.e. with a single
        argument. ``handler`` therefore presumably receives the annotation
        object (this instance) rather than a schema handler, and calling it
        lands in ``ValueInInterval.__call__``, which accepts no extra
        arguments -- hence the ``TypeError``. Confirm against the pydantic
        documentation for the expected ``(source_type, handler)`` signature.
        """
        schema = handler(self.type_definition)

        # The bounds are written straight into the returned schema mapping.
        if self.include_start:
            schema.update({"ge": self.start})
        else:
            schema.update({"gt": self.start})

        if self.include_end:
            schema.update({"le": self.end})
        else:
            schema.update({"lt": self.end})

        return schema


@validate_call()
def test_interval(
    value: ValueInInterval(type_definition=int, start=1, end=10),
):
    """Print ``value``; the annotation asks for an int in the closed interval [1, 10]."""
    print(value)


# Driver calls for the reproduction: 1 is the inclusive lower bound, 0 lies below it.
test_interval(value=1)  # should succeed
test_interval(value=0)  # should fail

Running this code in PyCharm's Python console, I get the following error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2023.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 56, in <module>
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\validate_call_decorator.py", line 56, in validate
    validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_validate_call.py", line 57, in __init__
    schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 512, in generate_schema
    schema = self._generate_schema_inner(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 789, in _generate_schema_inner
    return self.match_type(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 856, in match_type
    return self._callable_schema(obj)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1692, in _callable_schema
    arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1414, in _generate_parameter_schema
    schema = self._apply_annotations(source_type, annotations)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1890, in _apply_annotations
    schema = get_inner_schema(source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_schema_generation_shared.py", line 83, in __call__
    schema = self._handler(source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1869, in inner_handler
    from_property = self._generate_schema_from_property(obj, source_type)
  File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 677, in _generate_schema_from_property
    schema = get_schema(source)
  File "<input>", line 40, in __get_pydantic_core_schema__
TypeError: __call__() takes 1 positional argument but 2 were given

I strongly suspect that I've not written the __get_pydantic_core_schema__ method correctly. Note that I'm using pydantic version 2.7 and Python 3.11.


Solution

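  • I was able to find a solution - I had indeed made a mistake when writing the __get_pydantic_core_schema__ method. First, the method has to accept source_type ahead of handler (see the full definition below); with only one parameter, pydantic passed the annotation object itself in that position, which is why the traceback ended in ValueInInterval.__call__. Second, instead of calling handler as a function, I call its .generate_schema() method, i.e. I replaced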

    schema = handler(self.type_definition)
    

    with

    schema = handler.generate_schema(self.type_definition)
    

    So here is the full class definition that I ultimately went with:

    from typing import Any, Union
    from typing_extensions import Annotated
    from pydantic import Field, GetCoreSchemaHandler, validate_call
    from pydantic_core import core_schema
    
    
    class ValueInInterval:
        """Pydantic annotation that restricts a value to an interval.

        An instance is used directly as a parameter or field annotation, e.g.
        ``ValueInInterval(type_definition=int, start=1, end=10)``. Validation
        is driven by the ``__get_pydantic_core_schema__`` hook defined below.

        Args:
            type_definition: Type the value is validated as. Its core schema
                must accept the ``ge``/``gt``/``le``/``lt`` keys.
            start: Lower bound, or ``None`` for an interval with no lower
                bound.
            end: Upper bound, or ``None`` for an interval with no upper bound.
            include_start: If true the lower bound is inclusive (``ge``),
                otherwise exclusive (``gt``).
            include_end: If true the upper bound is inclusive (``le``),
                otherwise exclusive (``lt``).
        """

        def __init__(
            self,
            type_definition: Any,
            start: Any = None,
            end: Any = None,
            include_start: bool = True,
            include_end: bool = True,
        ) -> None:
            self.type_definition = type_definition
            self.start = start
            self.end = end
            self.include_start = include_start
            self.include_end = include_end

        def __repr__(self) -> str:
            """Return a constructor-style representation for debugging."""
            return (
                f"{type(self).__name__}("
                f"type_definition={self.type_definition!r}, "
                f"start={self.start!r}, end={self.end!r}, "
                f"include_start={self.include_start!r}, "
                f"include_end={self.include_end!r})"
            )

        def __get_pydantic_core_schema__(
            self,
            source_type: Any,
            handler: "GetCoreSchemaHandler",
        ) -> "core_schema.CoreSchema":
            """Return the core schema of ``type_definition`` plus the bounds.

            Args:
                source_type: The annotation being processed; unused, but it
                    must be accepted so that ``handler`` arrives as the second
                    argument.
                handler: Schema handler whose ``generate_schema`` method
                    builds the schema for ``type_definition``.

            Returns:
                A new schema mapping: the generated schema with ``ge``/``gt``
                and/or ``le``/``lt`` entries added for every bound that is not
                ``None``.
            """
            bounds = {}
            # ``is not None`` rather than truthiness, so that 0 is a valid bound.
            if self.start is not None:
                bounds["ge" if self.include_start else "gt"] = self.start
            if self.end is not None:
                bounds["le" if self.include_end else "lt"] = self.end

            # Merge into a fresh dict so the mapping produced by the handler
            # is never mutated in place.
            return {**handler.generate_schema(self.type_definition), **bounds}
    
    
    # usage example
    @validate_call()
    def test_interval(value: ValueInInterval(type_definition=int, start=1, end=10)):
        """Print ``value`` once it has been validated against the interval [1, 10]."""
        print(value)


    test_interval(value=1)  # passes: 1 is the inclusive lower bound
    test_interval(value=0)  # fails: 0 lies below the lower bound