I'm building a function-parameter validation library using pydantic. We want to be able to validate both the types and the values of parameters. Types are easy enough, but I'm having a hard time creating a class to validate values. Specifically, the first class I want to build is one that requires the value to lie within a user-defined interval.
So far, I've written a decorator and a function-based version of `ValueInInterval`. However, I would prefer to use a class-based approach. Here's an MRE of my issue:
from typing import Any
from typing_extensions import Annotated
from pydantic import Field, validate_call
class ValueInInterval:
    """Annotation object meant to restrict a value to an interval.

    ``include_start`` / ``include_end`` select an inclusive (``ge`` / ``le``)
    or exclusive (``gt`` / ``lt``) bound at each end.

    NOTE: ``__get_pydantic_core_schema__`` here accepts only ``handler``.
    The traceback shows pydantic invoking it as ``get_schema(source)``, so
    ``handler`` is bound to the annotation instance itself and
    ``handler(self.type_definition)`` lands in ``__call__``, which raises
    the reported ``TypeError``.
    """

    def __init__(
        self,
        type_definition: Any,
        start: Any,
        end: Any,
        include_start: bool = True,
        include_end: bool = True,
    ):
        self.type_definition = type_definition
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    def _bounds(self) -> dict:
        """Map the interval onto pydantic's lower/upper constraint keywords."""
        lower_key = "ge" if self.include_start else "gt"
        upper_key = "le" if self.include_end else "lt"
        return {lower_key: self.start, upper_key: self.end}

    def __call__(self):
        """Return ``Annotated[type_definition, Field(<interval bounds>)]``."""
        return Annotated[self.type_definition, self.create_field()]

    def create_field(self) -> Field:
        """Build a pydantic ``Field`` carrying the interval constraints."""
        return Field(**self._bounds())

    def __get_pydantic_core_schema__(
        self,
        handler,
    ):
        # Generate the schema for the wrapped type, then attach the bounds.
        core = handler(self.type_definition)
        core.update(self._bounds())
        return core
@validate_call()
def test_interval(
    value: ValueInInterval(type_definition=int, start=1, end=10),
):
    """Print ``value``; the annotation is meant to enforce 1 <= value <= 10."""
    print(value)


# Intended outcome: the first call prints 1, the second is rejected.
test_interval(value=1)  # inside [1, 10] -> should succeed
test_interval(value=0)  # below the lower bound -> should fail
Running this code in PyCharm's Python console, I get the following error:
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm 2023.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
coro = func()
File "<input>", line 56, in <module>
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\validate_call_decorator.py", line 56, in validate
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_validate_call.py", line 57, in __init__
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 512, in generate_schema
schema = self._generate_schema_inner(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 789, in _generate_schema_inner
return self.match_type(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 856, in match_type
return self._callable_schema(obj)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1692, in _callable_schema
arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1414, in _generate_parameter_schema
schema = self._apply_annotations(source_type, annotations)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1890, in _apply_annotations
schema = get_inner_schema(source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_schema_generation_shared.py", line 83, in __call__
schema = self._handler(source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 1869, in inner_handler
from_property = self._generate_schema_from_property(obj, source_type)
File "C:\projects\django-postgres-loader\venv\lib\site-packages\pydantic\_internal\_generate_schema.py", line 677, in _generate_schema_from_property
schema = get_schema(source)
File "<input>", line 40, in __get_pydantic_core_schema__
TypeError: __call__() takes 1 positional argument but 2 were given
I strongly suspect that I haven't written `__get_pydantic_core_schema__` correctly. Note that I'm using pydantic version 2.7 and Python 3.11.
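I was able to find a solution: I had indeed written the `__get_pydantic_core_schema__` method incorrectly.

The root problem was its signature. It took only `handler`. As the traceback shows (`schema = get_schema(source)`), pydantic then called it with a single argument, so `handler` was actually bound to the `ValueInInterval` instance itself. Calling `handler(self.type_definition)` therefore invoked `ValueInInterval.__call__` with an extra argument, hence the `TypeError`.

The fix had two parts:
1. Add the missing `source_type` parameter, so the method is called as `__get_pydantic_core_schema__(source_type, handler)`.
2. Build the schema of the wrapped type with the handler's `.generate_schema()` method instead of calling `handler` as a function.

For the second part, I replaced
schema = handler(self.type_definition)
with
schema = handler.generate_schema(self.type_definition)

Here is the full class definition that I ultimately went with: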
from typing import Any, Union
from typing_extensions import Annotated
from pydantic import Field, GetCoreSchemaHandler, validate_call
from pydantic_core import core_schema
class ValueInInterval:
    """Pydantic annotation that restricts a value to an interval.

    Use an *instance* directly as a parameter annotation, e.g.
    ``value: ValueInInterval(int, start=1, end=10)``. When pydantic builds the
    schema it calls ``__get_pydantic_core_schema__``, which generates the core
    schema of ``type_definition`` and adds the bound constraints to it.

    Args:
        type_definition: Type the value must have. The bounds are added as
            ``ge``/``gt``/``le``/``lt`` schema keys, so this should be a type
            whose core schema understands them (e.g. ``int``, ``float``).
        start: Lower bound, or ``None`` for no lower bound.
        end: Upper bound, or ``None`` for no upper bound.
        include_start: Inclusive lower bound (``ge``) if true, exclusive
            (``gt``) otherwise. Ignored when ``start`` is ``None``.
        include_end: Inclusive upper bound (``le``) if true, exclusive
            (``lt``) otherwise. Ignored when ``end`` is ``None``.

    Raises:
        ValueError: If both bounds are given and describe an empty interval
            (``start > end``, or ``start == end`` with an exclusive bound).
    """

    def __init__(
        self,
        type_definition: Any,
        start: Any = None,
        end: Any = None,
        include_start: bool = True,
        include_end: bool = True,
    ):
        if start is not None and end is not None:
            # An empty interval makes every call fail validation; report the
            # mistake where it is made instead of at each call site.
            if start > end or (start == end and not (include_start and include_end)):
                raise ValueError(
                    f"empty interval: start={start!r}, end={end!r}, "
                    f"include_start={include_start}, include_end={include_end}"
                )
        self.type_definition = type_definition
        self.start = start
        self.end = end
        self.include_start = include_start
        self.include_end = include_end

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}({self.type_definition!r}, "
            f"start={self.start!r}, end={self.end!r}, "
            f"include_start={self.include_start!r}, "
            f"include_end={self.include_end!r})"
        )

    def _constraints(self) -> dict[str, Any]:
        """Return the core-schema bound keys (ge/gt, le/lt) for this interval."""
        constraints: dict[str, Any] = {}
        if self.start is not None:
            constraints["ge" if self.include_start else "gt"] = self.start
        if self.end is not None:
            constraints["le" if self.include_end else "lt"] = self.end
        return constraints

    def __get_pydantic_core_schema__(
        self,
        source_type: Any,
        handler: "GetCoreSchemaHandler",
    ) -> "core_schema.CoreSchema":
        # The instance itself is the annotation, so ``source_type`` is not the
        # wrapped type; generate the schema from ``type_definition`` instead.
        schema = handler.generate_schema(self.type_definition)
        schema.update(self._constraints())
        return schema
# Usage example: 1 lies in [1, 10] and is accepted; 0 lies outside the
# interval and is rejected with a pydantic ValidationError.
from pydantic import ValidationError


@validate_call()
def test_interval(
    value: ValueInInterval(type_definition=int, start=1, end=10),
):
    """Print ``value`` after pydantic has checked that 1 <= value <= 10."""
    print(value)


test_interval(value=1)  # passes: prints 1

try:
    test_interval(value=0)  # fails: 0 is below start=1
except ValidationError as err:
    # Catch the expected failure so the demo runs to completion and shows
    # pydantic's error message instead of aborting with a traceback.
    print(err)