I'm trying to understand how Python dataclass validation can be implemented straightforwardly. I'm using marshmallow's `validate` module to try to do this, but I don't understand how the validation can actually be run within the dataclass, or whether it is just glommed on as a metadata field that you have to run yourself, rather awkwardly.
I could write a `__post_init__` (as suggested here and here) that performs the validation on each field directly, but I feel like there should be an easier, validator-agnostic way to validate all the fields according to their `validate` metadata, either at `__init__` or otherwise.

Here is an example script:
```python
from dataclasses import dataclass, field

from marshmallow import ValidationError, validate


def null_validate(value):
    """Validation fn for dataclass: only None is accepted."""
    if value is not None:
        raise ValidationError(f"{value} should be None for this dataclass field!")


@dataclass
class Testing:
    plus_minus_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.OneOf([1, -1])
        )
    )
    max_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.Range(max=1)
        )
    )
    null_field: str = field(
        default=None,
        metadata=dict(
            required=False,
            validate=null_validate
        )
    )


print("this passes")
it = Testing(1, 1, None)
print("this should fail")
it = Testing(10, 10, 10)
```
I run this as follows but don't get any `ValidationError`, so I know that the validation doesn't somehow happen magically inside the dataclass:

```
% python testing.py
this passes
this should fail
```
So what I can do is add a `__post_init__` method like this to the dataclass:

```python
def __post_init__(self):
    # Assumes every field has a "validate" entry in its metadata (KeyError otherwise)
    for data_field in self.__dataclass_fields__:
        self.__dataclass_fields__[data_field].metadata["validate"](
            self.__dict__[data_field]
        )
```
With this, the validation more or less works, although it stops at the first argument that fails:

```
% python testing.py
this passes
this should fail
Traceback (most recent call last):
  File "testing.py", line 47, in <module>
    it = Testing(10, 10, 10)
  File "<string>", line 6, in __init__
  File "testing.py", line 41, in __post_init__
    self.__dataclass_fields__[data_field].metadata["validate"](self.__dict__[data_field])
  File "/Users/max.press/miniconda3/envs/test_env/lib/python3.7/site-packages/marshmallow/validate.py", line 569, in __call__
    raise ValidationError(self._format_error(value))
marshmallow.exceptions.ValidationError: Must be one of: 1, -1.
```
But this seems rather clunky, and it seems hard to implement more complex validations this way. It seems like I should be able to validate "up-front", when the arguments are passed in, without adding this boilerplate to every class.
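For comparison, the `__post_init__` idea can be factored into a reusable base class so that each dataclass does not need its own loop. This is a minimal sketch, assuming validators are plain callables that raise on bad input. `ValidatedMixin`, `FieldValidationError`, `one_of` and `at_most` are hypothetical names, not part of marshmallow or the standard library. Unlike the loop above, it skips fields without `validate` metadata, accepts a list of validators per field, and reports every failing field at once rather than stopping at the first.

```python
from dataclasses import dataclass, field, fields
from typing import Callable, Dict, List


class FieldValidationError(ValueError):
    """Hypothetical exception carrying {field_name: [error messages]}."""

    def __init__(self, errors: Dict[str, List[str]]):
        super().__init__(errors)
        self.errors = errors


class ValidatedMixin:
    """Runs each field's metadata["validate"] (a callable or a list of callables)
    right after the dataclass-generated __init__ has assigned the fields."""

    def __post_init__(self) -> None:
        errors: Dict[str, List[str]] = {}
        for f in fields(self):
            validators = f.metadata.get("validate")
            if validators is None:
                continue  # no validator declared for this field
            if callable(validators):
                validators = [validators]
            value = getattr(self, f.name)
            for validator in validators:
                try:
                    validator(value)
                except Exception as exc:  # accept any validator library's error type
                    errors.setdefault(f.name, []).append(str(exc))
        if errors:
            raise FieldValidationError(errors)


# Hypothetical stdlib validators standing in for marshmallow's OneOf / Range.
def one_of(*choices) -> Callable[[object], None]:
    def check(value):
        if value not in choices:
            raise ValueError(f"Must be one of: {', '.join(map(str, choices))}.")
    return check


def at_most(limit) -> Callable[[object], None]:
    def check(value):
        if value > limit:
            raise ValueError(f"Must be less than or equal to {limit}.")
    return check


@dataclass
class Testing(ValidatedMixin):
    plus_minus_one: int = field(metadata={"validate": one_of(1, -1)})
    max_one: int = field(metadata={"validate": at_most(1)})
```

Because the mixin only calls whatever is stored under `"validate"`, marshmallow validators such as `validate.OneOf([1, -1])`, which are also callables that raise on failure, should slot in unchanged. Catching `Exception` is what keeps it validator-agnostic, at the cost of also folding genuine bugs inside a validator into the error dict. Note that this only validates at construction time, not on later attribute assignment.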
Is the solution to move to marshmallow-dataclass entirely? Perhaps treating the class as a `Schema` could handle this.
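It turns out that you can do this quite easily with marshmallow_dataclass: its `dataclass` decorator attaches a `Schema` class to the decorated class, and loading data through `Testing.Schema()` runs the field validators. The code below shows the desired behavior without the `__post_init__`, though I clearly need to read up more on marshmallow: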
```python
from dataclasses import field

from marshmallow import ValidationError, validate
from marshmallow_dataclass import dataclass  # also attaches a .Schema attribute


def null_validate(value):
    """Validation fn for dataclass: only None is accepted."""
    if value is not None:
        raise ValidationError(f"{value} should be None for this dataclass field!")


@dataclass
class Testing:
    plus_minus_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.OneOf([1, -1])
        )
    )
    max_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.Range(max=1)
        )
    )
    # With the str type, loading 10 fails with "Not a valid string."
    null_field: str = field(
        default=None,
        metadata=dict(
            required=False,
            validate=null_validate
        )
    )


print("this passes")
it = Testing.Schema().load({"plus_minus_one": 1, "max_one": 1, "null_field": None})
print("this should fail")
it = Testing.Schema().load({"plus_minus_one": 10, "max_one": 10, "null_field": 10})
```
When it's run, I get the desired result:

```
this passes
this should fail
[...]
marshmallow.exceptions.ValidationError: {'null_field': ['Not a valid string.'], 'plus_minus_one': ['Must be one of: 1, -1.'], 'max_one': ['Must be less than or equal to 1.']}
```