
How to create a python decorator whose args are the decorated function plus any arbitrary argument(s)


I've created decorators that wrap functions before, but in this case I don't need to wrap anything, so I suspect I'm using the wrong paradigm. Maybe somebody can help me figure this out and reach my ultimate goal.

What I imagined was a decorator that, when applied (once, when the class is defined), takes 3 arguments:

  • The decorated function (that resides inside a Model class)
  • The name of a data member of the class (i.e. a database field, e.g. a name field of type CharField)
  • The name of a parent key data member in the class (e.g. parent of type ForeignKey)

My decorator code would register the function, the field, and related key associated with it in a global list variable.

I would then have a class that inherits from Model and overrides save() and delete(). It would cycle through the global list to update the associated fields using the output of each function, and then call the parent model's .save() method so that the parent would update its decorated fields as well.
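The save-and-percolate flow described above can be sketched without Django. This is a minimal plain-Python illustration, not the question's actual code: `register_updater`, `MaintainedBase`, `Parent`, and `Child` are hypothetical names, and registration is done with an ordinary function call (rather than a decorator) so the sketch focuses only on how save() recomputes fields and then saves each distinct parent once.

```python
from collections import defaultdict

# Hypothetical registry mirroring updater_list: class name -> list of updater specs
updater_registry = defaultdict(list)


def register_updater(class_name, function_name, update_field=None, parent_field=None):
    """Record which method computes which field, and which attribute links to a parent."""
    updater_registry[class_name].append(
        {"function": function_name, "update_field": update_field, "parent_field": parent_field}
    )


class MaintainedBase:
    """Plain-Python stand-in for a Django model: save() recomputes derived fields."""

    def save(self):
        specs = updater_registry[type(self).__name__]
        # 1. Recompute each derived field from the output of its updater method
        for spec in specs:
            if spec["update_field"] is not None:
                setattr(self, spec["update_field"], getattr(self, spec["function"])())
        # (a real model would persist the record here, e.g. via super().save())
        # 2. Percolate the change to each distinct parent exactly once
        parent_names = []
        for spec in specs:
            name = spec["parent_field"]
            if name is not None and name not in parent_names:
                parent_names.append(name)
        for name in parent_names:
            parent = getattr(self, name)
            if parent is not None:
                parent.save()


class Parent(MaintainedBase):
    def __init__(self):
        self.children = []
        self.name = None

    def _name(self):
        # Derived entirely from the children's (already updated) names
        return ";".join(child.name for child in self.children)


class Child(MaintainedBase):
    def __init__(self, compound, parent):
        self.compound = compound
        self.parent = parent
        self.name = None
        parent.children.append(self)

    def _name(self):
        return self.compound


register_updater("Child", "_name", update_field="name", parent_field="parent")
register_updater("Parent", "_name", update_field="name")

group = Parent()
Child("lysine", group).save()
Child("glucose", group).save()
print(group.name)  # lysine;glucose
```

Updating the child's own field before saving the parent matters here: the parent's updater reads the children's names, so the order of the two loops in save() is what keeps the parent consistent.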

I quickly realized, though, that the decorated function isn't being passed to the decorator, because I get the exception I created for the case where neither a field nor a parent key is supplied to the decorator when the class is defined.

In case this isn't clear, here's the code I have:

from typing import Dict, List

from django.conf import settings
from django.db.models import Model

updater_list: Dict[str, List] = {}


def field_updater_function(fn, update_field_name=None, parent_field_name=None):
    """
    This is a decorator for functions in a Model class that are identified to be used to update a supplied field and
    fields of any linked parent record (if the record is changed).  The function should return a value compatible with
    the field type supplied.  These decorators are identified by the MaintainedModel class, whose save and delete
    methods override the parent model and call the given functions to update the supplied field.  It also calls linked
    dependent models (if supplied) update methods.
    """

    if update_field_name is None and parent_field_name is None:
        raise Exception(
            "Either an update_field_name or parent_field_name argument is required."
        )

    # Get the name of the class the function belongs to
    class_name = fn.__qualname__.split(".")[0]
    func_dict = {
        "function": fn.__name__,
        "update_field": update_field_name,
        "parent_field": parent_field_name,
    }
    if class_name in updater_list:
        updater_list[class_name].append(func_dict)
    else:
        updater_list[class_name] = [func_dict]
    if settings.DEBUG:
        print(f"Added field_updater_function decorator to function {fn.__qualname__}")
    return fn


class MaintainedModel(Model):
    """
    This class maintains database field values for a django.models.Model class whose values can be derived using a
    function.  If a record changes, the decorated function is used to update the field value.  It can also propagate
    changes of records in linked models.  Every function in the derived class decorated with the
    `@field_updater_function` decorator (defined above, outside this class) will be called and the associated field
    will be updated.  Only methods that take no arguments are supported.  This class overrides the class's save and
    delete methods as triggers for the updates.
    """

    def save(self, *args, **kwargs):
        # Set the changed value triggering this update
        super().save(*args, **kwargs)
        # Update the fields that change due to the above change (if any)
        self.update_decorated_fields()
        # Now save the updated values (i.e. save again)
        super().save(*args, **kwargs)
        # Percolate changes up to the parents (if any)
        self.call_parent_updaters()

    def delete(self, *args, **kwargs):
        # Delete the record triggering this update
        super().delete(*args, **kwargs)  # Call the "real" delete() method.
        # Percolate changes up to the parents (if any)
        self.call_parent_updaters()

    def update_decorated_fields(self):
        """
        Updates every field identified in each field_updater_function decorator that generates its value
        """
        for updater_dict in self.get_my_updaters():
            update_fun = getattr(self, updater_dict["function"])
            update_fld = updater_dict["update_field"]
            if update_fld is not None:
                setattr(self, update_fld, update_fun())

    def call_parent_updaters(self):
        parents = []
        for updater_dict in self.get_my_updaters():
            parent_fld = updater_dict["parent_field"]
            if parent_fld is not None and parent_fld not in parents:
                parents.append(parent_fld)

        for parent_fld in parents:
            parent_instance = getattr(self, parent_fld)
            if isinstance(parent_instance, MaintainedModel):
                parent_instance.save()
            elif hasattr(parent_instance, "all"):
                # A ManyToManyField attribute is a related manager: save each related record
                for related_parent in parent_instance.all():
                    related_parent.save()
            else:
                raise Exception(
                    f"Parent class {parent_instance.__class__.__name__} or {self.__class__.__name__} must inherit "
                    f"from {MaintainedModel.__name__}."
                )

    @classmethod
    def get_my_updaters(cls):
        """
        Convenience method to retrieve all the updater functions of the calling model.
        """
        if cls.__name__ in updater_list:
            return updater_list[cls.__name__]
        else:
            return []

    class Meta:
        abstract = True

And here's the first decorator I applied, which triggers the exception when the class is defined:

class Tracer(models.Model, TracerLabeledClass):

    id = models.AutoField(primary_key=True)
    name = models.CharField(
        max_length=256,
        unique=True,
        help_text="A unique name or lab identifier of the tracer, e.g. 'lysine-C14'.",
    )
    compound = models.ForeignKey(
        to="DataRepo.Compound",
        on_delete=models.RESTRICT,
        null=False,
        related_name="tracer",
    )

    class Meta:
        verbose_name = "tracer"
        verbose_name_plural = "tracers"
        ordering = ["name"]

    def __str__(self):
        return str(self._name())

    @field_updater_function("name", "infusates")
    def _name(self):
        # format: `compound - [ labelname,labelname,... ]` (but no spaces)
        if self.labels is None or self.labels.count() == 0:
            return self.compound.name
        return (
            self.compound.name
            + "-["
            + ",".join(list(map(lambda l: str(l), self.labels.all())))
            + "]"
        )

And my exception:

...
  File ".../tracer.py", line 31, in Tracer
    @field_updater_function("name")
  File ".../maintained_model.py", line 19, in field_updater_function
    raise Exception(
Exception: Either an update_field_name or parent_field_name argument is required.
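The traceback follows from how Python evaluates a decorator that takes arguments: `@deco("name")` above `def _name` means `_name = deco("name")(_name)`. The call `deco("name")` happens first, so the string lands in the `fn` parameter and both field-name parameters stay `None`, which is exactly the condition the validation rejects. The minimal sketch below uses `naive_decorator`, a hypothetical stand-in for the question's `field_updater_function` with the validation removed, to show what each parameter actually receives.

```python
# Hypothetical stand-in for the question's field_updater_function, minus the
# validation: the decorated function is expected as the *first* parameter.
calls = []


def naive_decorator(fn, update_field_name=None, parent_field_name=None):
    calls.append((fn, update_field_name, parent_field_name))  # record what arrived
    return fn


try:
    # Python evaluates this as: _name = naive_decorator("name")(_name)
    @naive_decorator("name")
    def _name(self):
        return "x"
except TypeError as exc:
    error = exc

print(calls)  # [('name', None, None)]
print(error)  # 'str' object is not callable
```

Without the validation, the failure simply moves one step later: `naive_decorator("name")` returns the string `"name"`, and Python then tries to call that string with the function.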

The basic idea is that we have a number of database fields whose values can be derived entirely from other fields in the database. We originally used cached_property, but it provided virtually no speedup, so we'd rather save the computed values in the database.

I'd written a caching mechanism that auto-refreshes the cache by overriding .save() and .delete(), and that works great, but it has various drawbacks.

We could write a custom .save() override that explicitly calls the function to update every field, but I wanted something that makes maintaining field values as simple as decorating the functions that perform the updates, supplying only the field each one computes and the links to other affected records up the hierarchy. For example:

    @field_updater_function("name", "infusates")
    def _name(self):
        ...

Is there something other than decorators I should be using to accomplish this? I could make a dummy decorator using functools.wraps that returns the supplied function as is (I think), but that feels wrong.


Solution

  • You need to make a decorator factory. That is, a function you call with arguments that returns a decorator function that gets passed the function to be decorated.

    A typical way to do that is with nested functions. A function defined within another function can access the variables in the enclosing function's namespace. Here's what I think it would look like for your code:

    def field_updater_function(update_field_name=None, parent_field_name=None): # factory
        # docstring omitted for brevity
        if update_field_name is None and parent_field_name is None:
            raise Exception(
                "Either an update_field_name or parent_field_name argument is required."
            )
    
        def decorator(fn):                                         # the actual decorator
            class_name = fn.__qualname__.split(".")[0]
            func_dict = {
                "function": fn.__name__,
                "update_field": update_field_name,        # you can still access variables
                "parent_field": parent_field_name,        # from the enclosing namespace
            }
            if class_name in updater_list:
                updater_list[class_name].append(func_dict)
            else:
                updater_list[class_name] = [func_dict]
            if settings.DEBUG:
                print(f"Added field_updater_function decorator to function {fn.__qualname__}")
            return fn
    
        return decorator                          # here the factory returns the decorator
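To check the factory in isolation, here is a minimal plain-Python sketch of the same pattern. It assumes no Django: the `settings.DEBUG` print is dropped, `dict.setdefault` replaces the if/else, `ValueError` is raised instead of a bare `Exception`, and `Tracer` is a hypothetical plain class standing in for the model.

```python
from typing import Dict, List

# Registry: class name -> list of updater specs (same shape as updater_list above)
updater_list: Dict[str, List[dict]] = {}


def field_updater_function(update_field_name=None, parent_field_name=None):
    """Decorator factory: validates its own arguments, then returns the real decorator."""
    if update_field_name is None and parent_field_name is None:
        raise ValueError("Either an update_field_name or parent_field_name argument is required.")

    def decorator(fn):
        # fn.__qualname__ looks like "Tracer._name"; the first component is the class name
        class_name = fn.__qualname__.split(".")[0]
        updater_list.setdefault(class_name, []).append(
            {
                "function": fn.__name__,
                "update_field": update_field_name,
                "parent_field": parent_field_name,
            }
        )
        return fn  # the original function, unwrapped, so functools.wraps is not needed

    return decorator


class Tracer:  # plain class standing in for the Django model
    @field_updater_function("name", "infusates")
    def _name(self):
        return "lysine-C14"


print(updater_list)
# {'Tracer': [{'function': '_name', 'update_field': 'name', 'parent_field': 'infusates'}]}
print(Tracer()._name())  # lysine-C14

try:
    field_updater_function()  # no field names: fails as soon as the factory is called
except ValueError as exc:
    print(exc)  # Either an update_field_name or parent_field_name argument is required.
```

Because `decorator` returns `fn` itself, the method behaves exactly as before and only the registry changes, which is why no wrapper (and no `functools.wraps`) is involved. One caveat of the `__qualname__` parsing, present in both the question and the answer: a model class nested inside another class, or defined inside a function, registers under the outer name rather than its own.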