Tags: python, flask, flask-sqlalchemy, flask-marshmallow

Pass parameters to a class decorator in Python


I'm trying to write a decorator for a DB model that makes the model serializable:

def Schema(cls):
    class Schema(marshmallow.ModelSchema):
        class Meta:
            model = cls

    cls.Schema = Schema
    return cls


@Schema
class SerialInterface(sql.Model, InheritanceModel):
    id = sql.Column(types.Integer, primary_key=True)
    transmission_rate = sql.Column(types.Integer)
    type = sql.Column(sql.String(50))

    mad_id = sql.Column(types.Integer, sql.ForeignKey('mad.id'))

    serial_protocol = sql.relationship(SerialProtocol, uselist=False, cascade="all, delete-orphan")

But I want to pass the nested objects to this decorator, like this:

@Schema(nested=['serial_protocol'])
class SerialInterface(sql.Model, InheritanceModel):

Solution

  • You can do something like:

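    def Schema(*args, **kwargs):
        # args/kwargs are captured by the closure, so wrapped can read them
        def wrapped(cls):
            class Schema(marshmallow.ModelSchema):
                class Meta:
                    model = cls
    
            # kwargs.get('nested', []) can be used here to customize the schema
            cls.Schema = Schema
            return cls
        return wrapped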
    

    Then @Schema(nested=['serial_protocol']) will work, and inside wrapped the list is available as kwargs['nested'].
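A minimal, dependency-free sketch of the same factory where nested is actually consumed. BaseSchema is a hypothetical stand-in for marshmallow.ModelSchema so the example runs without Flask or SQLAlchemy, and nested_fields is a hypothetical Meta attribute used only to show that the value reaches the generated class; it is not a marshmallow option.

```python
# BaseSchema is a hypothetical stand-in for marshmallow.ModelSchema,
# used so this sketch runs without Flask/SQLAlchemy.
class BaseSchema:
    pass


def Schema(*args, **kwargs):
    # kwargs is captured by the closure below, so wrapped can read it
    nested = tuple(kwargs.get("nested", ()))

    def wrapped(cls):
        class Schema(BaseSchema):
            class Meta:
                model = cls
                # hypothetical attribute: records which relationships to nest
                nested_fields = nested

        cls.Schema = Schema
        return cls

    return wrapped


@Schema(nested=["serial_protocol"])
class SerialInterface:
    pass


print(SerialInterface.Schema.Meta.model is SerialInterface)  # True
print(SerialInterface.Schema.Meta.nested_fields)  # ('serial_protocol',)
```

In the real model you would use the nested names to add fields to the generated schema (for example Nested fields pointing at the related model's schema); how you look up that related schema depends on your setup.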

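    This works because Schema is now a function that takes the arguments and returns a decorator (wrapped). Python first calls Schema(nested=[...]), then applies the returned decorator to the class exactly like a normal decorator. One consequence: this version must always be called with parentheses, because a bare @Schema would pass the class in args and replace it with the inner wrapped function.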
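The question originally applied the decorator without parentheses. If you want both @schema and @schema(nested=[...]) to work, one common pattern is to make the class an optional positional argument and the options keyword-only. This is a sketch: the name schema is hypothetical, and it only records the options on the class instead of building a marshmallow schema.

```python
def schema(cls=None, *, nested=()):
    """Class decorator usable as @schema or @schema(nested=[...])."""

    def wrapped(target):
        # Hypothetical: just record the options on the class
        target.nested_fields = tuple(nested)
        return target

    if cls is None:
        # Called with arguments, e.g. @schema(nested=[...]): return the decorator
        return wrapped
    # Called bare, e.g. @schema: cls is the class being decorated
    return wrapped(cls)


@schema
class Plain:
    pass


@schema(nested=["serial_protocol"])
class WithNested:
    pass


print(Plain.nested_fields)       # ()
print(WithNested.nested_fields)  # ('serial_protocol',)
```

Making nested keyword-only means a positional argument can only ever be the decorated class, so the two call forms cannot be confused.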

    @Schema(nested=['serial_protocol'])
    class SerialInterface:
        ...
    

    The decorator syntax is equivalent to:

    SerialInterface = Schema(nested=['serial_protocol'])(SerialInterface)
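To see the equivalence, here is a small sketch that logs each step. The calls list and the nested_fields attribute are illustrative additions, not part of the answer's code. The factory runs first, then the returned decorator is applied to the class, whether you use decorator syntax or the explicit call.

```python
calls = []


def Schema(*args, **kwargs):
    calls.append(("factory", kwargs))

    def wrapped(cls):
        calls.append(("decorator", cls.__name__))
        cls.nested_fields = tuple(kwargs.get("nested", ()))
        return cls

    return wrapped


# Decorator syntax ...
@Schema(nested=["serial_protocol"])
class A:
    pass


# ... and the explicit two-step call it is equivalent to
class B:
    pass


B = Schema(nested=["serial_protocol"])(B)

print([step for step, _ in calls])  # ['factory', 'decorator', 'factory', 'decorator']
print(A.nested_fields == B.nested_fields)  # True
```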
    

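    Extra tip: if a decorator returns a new wrapper function instead of the original object, decorate that wrapper with functools.wraps (a decorator from the functools module) so it keeps the original __name__, __doc__ and other metadata. It isn't needed here, because wrapped returns cls itself.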
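functools.wraps matters for decorators that replace the decorated object with a new function. A minimal sketch using a hypothetical function decorator factory named logged:

```python
import functools


def logged(prefix):
    """Decorator factory for functions: takes options, returns a decorator."""

    def decorator(func):
        @functools.wraps(func)  # copy __name__, __doc__, etc. from func
        def wrapper(*args, **kwargs):
            print(f"{prefix}: calling {func.__name__}")
            return func(*args, **kwargs)

        return wrapper

    return decorator


@logged("debug")
def serialize(obj):
    """Return a dict representation of obj."""
    return dict(obj)


print(serialize.__name__)  # serialize (would be 'wrapper' without wraps)
print(serialize.__doc__)   # Return a dict representation of obj.
```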