
Python Marshmallow: set different fields as 'required' for each variant of the same schema


I have the following schema:

from marshmallow import Schema, fields

class MySchema(Schema):
    id = fields.Str()
    name = fields.Str()
    value = fields.Str()
    description = fields.Str()

I want to validate some input data with Flask for different types of requests.

For GET requests, I want the id field to be required (id = fields.Str(required=True)).

For POST requests, I do not need the id to be present, but the name, value and description fields to have required=True. I know how to exclude the id when validating the input data: schema = MySchema(exclude=["id"]).

For PUT requests, I need id and at least one of name, value or description.

I could just have 3 different schemas, but I would like to avoid that as much as possible.

I was thinking of something along the lines of schema = MySchema(require=["id"]), where I can dynamically set the required fields when I instantiate the schema.


Solution

  • I found a nice solution to this issue, using the context property of the schema and the @validates_schema decorator.

    schema = MySchema()
    schema.context["method"] = "POST"
    
    ...
    
    schema = MySchema()
    schema.context["method"] = "PUT"
    

    In MySchema:

    from marshmallow import Schema, validates_schema, ValidationError, fields
    
    class MySchema(Schema):
        id = fields.Str()
        name = fields.Str()
        value = fields.Str()
        description = fields.Str()
    
        @validates_schema
        def validate_schema(self, data, **kwargs):
            # Only enforce method-specific rules when a method was set in the context.
            if "method" in self.context:
                required_fields = []
                if self.context["method"] == "POST":
                    required_fields = ["name", "value", "description"]
                elif self.context["method"] == "PUT":
                    required_fields = ["id"]
    
                # `data` only contains the fields that were present in the input.
                missing_fields = [field for field in required_fields if field not in data]
                if missing_fields:
                    # Key each error by field name, like marshmallow's own "required" errors.
                    raise ValidationError(
                        {field: ["Missing data for required field."] for field in missing_fields}
                    )
    

    The same approach also works when you only know at runtime which fields are required: pass the list of required fields as a key-value pair in the context dictionary and read it in the validator.