Tags: python, validation, python-dataclasses, marshmallow

Python dataclass validation: an easy way?


I'm trying to understand how Python dataclass validation can be implemented straightforwardly. I'm using marshmallow's validate module for this, but I don't understand how the validation can actually be run within the dataclass, or whether it is just glommed on as a metadata field that you have to run yourself, rather awkwardly.

I could do a __post_init__ (as suggested here and here) to directly perform the validation on each field but I feel like there should be an easier, validator-agnostic way to validate all the fields according to their validate metadata, either at __init__ or otherwise.

Here is an example script below:

from dataclasses import dataclass, field
from marshmallow import ValidationError, validate


def null_validate(value):
    """Validation fn for dataclass: only None is accepted."""
    if value is not None:
        # f-string so the offending value actually appears in the message
        raise ValidationError(f"{value} should be None for this dataclass field!")


@dataclass
class Testing:
    plus_minus_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.OneOf([1, -1])
        )
    )
    max_one: int = field(
        default=None,
        metadata=dict(
            required=False,
            validate=validate.Range(max=1)
        )
    )
    null_field: str = field(
        default=None,
        metadata=dict(
            required=False,
            validate=null_validate
        )
    )

print("this passes")
it = Testing(1, 1, None)
print("this should fail")
it = Testing(10, 10, 10)

I run this as follows but don't get any ValidationError, so I know that the validation doesn't somehow happen magically inside the dataclass:

% python testing.py
this passes
this should fail
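For context, the standard library treats a field's metadata as an opaque mapping: dataclasses stores it but never reads or calls anything in it. The sketch below uses only the standard library to show this; `must_be_positive` and `Example` are hypothetical names made up for illustration, not part of the script above.

```python
from dataclasses import dataclass, field, fields


def must_be_positive(value):
    """Hypothetical validator: raise if the value is not positive."""
    if value <= 0:
        raise ValueError(f"{value} must be positive")


@dataclass
class Example:
    count: int = field(default=1, metadata={"validate": must_be_positive})


# Constructing with an invalid value raises nothing: metadata is never called.
bad = Example(count=-5)
print(bad)  # Example(count=-5)

# The validator is only reachable by looking it up and calling it explicitly.
validator = fields(Example)[0].metadata["validate"]
try:
    validator(bad.count)
except ValueError as err:
    print(f"manual call failed: {err}")
```

So any validation has to be triggered by something outside the generated `__init__`, such as a `__post_init__` hook or a separate schema layer.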

So what I can do is add a __post_init__ method like this to the dataclass:

def __post_init__(self):
    for data_field in self.__dataclass_fields__:
        self.__dataclass_fields__[data_field].metadata["validate"](
            self.__dict__[data_field]
        )

With this, the validation more or less works on an argument-wise basis:

% python testing.py
this passes
this should fail
Traceback (most recent call last):
  File "testing.py", line 47, in <module>
    it = Testing(10, 10, 10)
  File "<string>", line 6, in __init__
  File "testing.py", line 41, in __post_init__
    self.__dataclass_fields__[data_field].metadata["validate"](self.__dict__[data_field])
  File "/Users/max.press/miniconda3/envs/test_env/lib/python3.7/site-packages/marshmallow/validate.py", line 569, in __call__
    raise ValidationError(self._format_error(value))
marshmallow.exceptions.ValidationError: Must be one of: 1, -1.

But this seems rather clunky, and it seems hard to implement more complex validations than this. It seems like I should be able to validate "up-front" when the argument is passed in, without changing anything.
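One way to make the `__post_init__` loop less clunky is to move it into a reusable base class, so each dataclass only declares its validators. The following is a minimal sketch under some assumptions, not an established pattern from any library: `ValidatedDataclass`, `one_of`, and `at_most` are hypothetical names, the validators are plain callables that raise on bad input (marshmallow validators would fit the same shape), and fields left at `None` are skipped unless their metadata says `required=True`.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


class ValidatedDataclass:
    """Hypothetical mixin: run each field's metadata["validate"] after __init__."""

    def __post_init__(self):
        errors = {}
        for f in fields(self):
            validator = f.metadata.get("validate")
            value = getattr(self, f.name)
            if validator is None:
                continue  # field declares no validator
            if value is None and not f.metadata.get("required", False):
                continue  # unset optional field
            try:
                validator(value)
            except Exception as exc:  # validator-agnostic: any exception counts
                errors[f.name] = str(exc)
        if errors:
            # Report every failing field at once instead of stopping at the first.
            raise ValueError(errors)


def one_of(*choices):
    """Hypothetical validator factory: value must be one of the choices."""
    def check(value):
        if value not in choices:
            raise ValueError(f"must be one of {choices}")
    return check


def at_most(limit):
    """Hypothetical validator factory: value must not exceed the limit."""
    def check(value):
        if value > limit:
            raise ValueError(f"must be <= {limit}")
    return check


@dataclass
class Testing(ValidatedDataclass):
    plus_minus_one: Optional[int] = field(
        default=None, metadata={"validate": one_of(1, -1)}
    )
    max_one: Optional[int] = field(
        default=None, metadata={"validate": at_most(1)}
    )


Testing(1, 1)  # passes
Testing()      # passes: unset optional fields are skipped
try:
    Testing(10, 10)
except ValueError as err:
    print(err.args[0])  # dict mapping each failing field name to its message
```

This still only validates at construction time. Later attribute assignments are not checked, which is one reason a schema layer can be the better fit for more complex validation.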

Is the solution to move to a full marshmallow-dataclass? Treating the class as a Schema might handle this.


Solution

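  • It turns out that you can do this quite easily with the marshmallow_dataclass package: its dataclass decorator attaches a Schema class to the dataclass, and loading data through Testing.Schema().load(...) runs the validators from each field's metadata. Note that validation happens on load; constructing Testing(...) directly still skips it.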

    The code below shows the desired behavior without the __post_init__, though I clearly need to read up more on marshmallow:

    from dataclasses import field
    from marshmallow import ValidationError, validate
    from marshmallow_dataclass import dataclass  # adds a Schema attribute to the class
    
    
    def null_validate(value):
        """Validation fn for dataclass: only None is accepted."""
        if value is not None:
            # f-string so the offending value actually appears in the message
            raise ValidationError(f"{value} should be None for this dataclass field!")
    
    
    @dataclass
    class Testing:
        plus_minus_one: int = field(
            default=None,
            metadata=dict(
                required=False,
                validate=validate.OneOf([1, -1])
            )
        )
        max_one: int = field(
            default=None,
            metadata=dict(
                required=False,
                validate=validate.Range(max=1)
            )
        )
        null_field: str = field(
            default=None,
            metadata=dict(
                required=False,
                validate=null_validate
            )
        )
    
    print("this passes")
    it = Testing.Schema().load({"plus_minus_one": 1, "max_one": 1, "null_field": None})
    print("this should fail")
    it = Testing.Schema().load({"plus_minus_one": 10, "max_one": 10, "null_field": 10})
    

    When it's run, I get the desired result:

    this passes
    this should fail
    [...]
    marshmallow.exceptions.ValidationError: {'null_field': ['Not a valid string.'], 'plus_minus_one': ['Must be one of: 1, -1.'], 'max_one': ['Must be less than or equal to 1.']}