
Using Marshmallow without repeating myself


According to the official Marshmallow docs, it's recommended to declare a Schema and then have a separate class that receives loaded data, like this:

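from marshmallow import Schema, fields, post_load

class UserSchema(Schema):
    name = fields.Str()
    email = fields.Email()
    created_at = fields.DateTime()

    # marshmallow 3 passes extra keyword arguments (e.g. many, partial) to hooks
    @post_load
    def make_user(self, data, **kwargs):
        return User(**data)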

However, my User class would look something like this:

class User:
    def __init__(self, name, email, created_at):
        self.name = name
        self.email = email
        self.created_at = created_at

This seems like unnecessary repetition, and I really don't like having to write the attribute names three more times. However, I do like IDE autocompletion and static type checking on well-defined structures.

So, is there any best practice for loading serialized data according to a Marshmallow Schema without defining another class?


Solution

  • Unless you need to deserialize as a specific class or you need custom serialization logic, you can simply do this (adapted from https://kimsereylam.com/python/2019/10/25/serialization-with-marshmallow.html):

    from marshmallow import Schema, fields
    from datetime import datetime
    
    class UserSchema(Schema):
        name = fields.Str(required=True)
        email = fields.Email()
        created_at = fields.DateTime()
    
    schema = UserSchema()
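    # load() expects serialized input, so created_at is passed as an ISO 8601 string
    data = {"name": "Some Guy", "email": "sguy@google.com", "created_at": datetime.now().isoformat()}
    user = schema.load(data)  # a plain dict of validated, deserialized values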
    

    You could also add a class method to your model that returns a dict of fields with their validation rules. It is still redundant, but it lets you keep everything in your model class:

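    from marshmallow import Schema, fields
    
    class User:
        def __init__(self, name, email, created_at):
            self.name = name
            self.email = email
            self.created_at = created_at
    
        @classmethod
        def Schema(cls):
            # Field definitions live next to the model they describe
            return {"name": fields.Str(), "email": fields.Email(), "created_at": fields.DateTime()}
    
    # from_dict expects the dict itself, so the class method must be called
    UserSchema = Schema.from_dict(User.Schema())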
    

    If you need strong typing and full validation functionality, consider flask-pydantic or marshmallow-dataclass.

    marshmallow-dataclass offers many of the same validation features as marshmallow, but it ties your hands somewhat. It doesn't have built-in support for custom fields or polymorphism (you have to use marshmallow-union instead), and it doesn't seem to play well with add-on packages like flask-marshmallow and marshmallow-sqlalchemy. https://pypi.org/project/marshmallow-dataclass/

    from dataclasses import field
    from typing import ClassVar, Type
    from marshmallow_dataclass import dataclass
    from marshmallow import Schema, validate
    
    
    @dataclass
    class Person:
        name: str = field(metadata=dict(load_only=True))
        height: float = field(metadata=dict(validate=validate.Range(min=0)))
        Schema: ClassVar[Type[Schema]] = Schema
    
    
    Person.Schema().dump(Person('Bob', 2.0))
    # => {'height': 2.0}
    

    flask-pydantic is less elegant from a validation standpoint, but it offers many of the same features, and the validation is built into the class. Note that simple validations like min/max are more awkward than in marshmallow. Personally, though, I prefer to keep view/API logic out of the class. https://pypi.org/project/Flask-Pydantic/

    from typing import Optional
    from flask import Flask, request
    from pydantic import BaseModel
    
    from flask_pydantic import validate
    
    app = Flask("flask_pydantic_app")
    
    class QueryModel(BaseModel):
      age: int
    
    class ResponseModel(BaseModel):
      id: int
      age: int
      name: str
      nickname: Optional[str]
    
    # Example 1: query parameters only
    @app.route("/", methods=["GET"])
    @validate()
    def get(query:QueryModel):
      age = query.age
      return ResponseModel(
        age=age,
        id=0, name="abc", nickname="123"
        )
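
The same kind of model can also be used without Flask. The sketch below assumes only that pydantic is installed and reuses the ResponseModel shape from the example above (with a default added for nickname, which is my own choice here); it shows that validation runs when the object is constructed:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class ResponseModel(BaseModel):
    id: int
    age: int
    name: str
    nickname: Optional[str] = None  # default added for this sketch


# Numeric strings are coerced to int by pydantic's default (non-strict) mode.
ok = ResponseModel(id=0, age="42", name="abc")

# A value that cannot be converted raises ValidationError.
try:
    ResponseModel(id=0, age="not a number", name="abc")
    error_fields = []
except ValidationError as err:
    # Each error records the location of the offending field.
    error_fields = [e["loc"][0] for e in err.errors()]
```

Here the class definition is the single source of truth: the attribute names appear once, and the resulting objects are fully typed.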