
Handling multiple variants of a marshmallow schema


I have a simple Flask-SQLAlchemy model, which I'm writing a REST API for:

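import datetime as dt

from sqlalchemy import Column, DateTime, ForeignKey, Integer, Unicode

# db is the Flask-SQLAlchemy instance; CRUDMixin is the app's own helper mixin
class Report(db.Model, CRUDMixin):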
    report_id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('users.user_id'), index=True)
    report_hash = Column(Unicode, index=True, unique=True)
    created_at = Column(DateTime, nullable=False, default=dt.datetime.utcnow)
    uploaded_at = Column(DateTime, nullable=False, default=dt.datetime.utcnow)

Then I have the corresponding Marshmallow-SQLAlchemy schema:

from marshmallow_sqlalchemy import ModelSchema

class ReportSchema(ModelSchema):
    class Meta:
        model = Report

However, in my REST API, I need to be able to dump and load slightly different variants of this model:

  • When dumping all the reports (e.g. GET /reports), I want to dump all the above fields.
  • When dumping a single report (e.g. GET /reports/1), I want to dump all this data, and also all associated relations, such as the associated Sample objects from the sample table (one report has many Samples)
  • When creating a new report (e.g. POST /reports), I want the user to provide all the report fields except report_id (which will be generated), report_hash and uploaded_at (which will be calculated on the spot), and also I want them to include all the associated Sample objects in their upload.

How can I reasonably maintain 3 (or more) versions of this schema? Should I:

  • Have 3 separate ModelSchema subclasses? e.g. AggregateReportSchema, SingleReportSchema, and UploadReportSchema?
  • Have one mega-ModelSchema that includes all fields I could ever want in this schema, and then I subtract fields from it on the fly using the exclude argument in the constructor? e.g. ReportSchema(exclude=[])?
  • Or should I use inheritance and define a class ReportBaseSchema(ModelSchema), and the other schemas subclass this to add additional fields (e.g. class UploadReportSchema(ReportBaseSchema))?
  • Something else?

Solution

  • Since asking this question, I've done a lot of work with Marshmallow, so hopefully I can shed some light on this.

    My rule of thumb is this: do as much as you can with the schema constructor (option #2, a single schema customized through constructor arguments such as exclude), and only resort to inheritance (option #3) if you absolutely have to. Never use option #1 (separate, unrelated subclasses), because it results in unnecessary, duplicated code.

    The schema constructor approach is great because:

    • You end up writing the least code
    • You never have to duplicate logic (e.g. validation)
    • The only, exclude, partial and unknown arguments to the schema constructor give you more than enough power to customize the individual schemas (see the documentation)
    • Schema subclasses can add extra settings to the schema constructor. For example, marshmallow-jsonapi adds include_data, which lets you control how much data you return for each related resource

    My original post is a situation where the schema constructor alone is sufficient. First define a schema that includes every field you might need, including relationships, which can be declared as Nested fields. Then, whenever a view should leave out related resources or superfluous fields, exclude them in that view method, e.g. ReportSchema(exclude=['some', 'fields']).dump(report).

    However, one case where inheritance was a better fit was when I modelled the arguments for certain graphs I was generating. I wanted a set of general arguments that would be passed to the underlying plotting library, but I also wanted child schemas that refine those arguments with more specific validation:

    from marshmallow import Schema, fields as f
    
    
    class PlotSchema(Schema):
        """
        Data that can be used to generate a plot
        """
        id = f.String(dump_only=True)
        type = f.String()
        x = f.List(f.Raw())
        y = f.List(f.Raw())
        text = f.List(f.Raw())
        hoverinfo = f.Str()
    
    
    class TrendSchema(PlotSchema):
        """
        Data that can be used to generate a trend plot
        """
        x = f.List(f.DateTime())
        y = f.List(f.Number())