Search code examples
Tags: django, orm, symmetric

Symmetric django model


I want to create a model that is symmetric in two of its fields. Let's call the model `Balance`:

class Balance (models.Model):
    """Amount between two users, as sketched in the question.

    ``...`` stands for the field options that were left out of the sketch.
    """
    payer = models.ForeignKey(auth.User, ...)
    payee = models.ForeignKey(auth.User, ...)
    amount = models.DecimalField(...)

It should have the following property:

# Look up the same pair of users in both orientations.
balance_forward = Balance.objects.get(payer=USER_1, payee=USER_2)
balance_backward = Balance.objects.get(payer=USER_2, payee=USER_1)

# Desired invariant: swapping payer and payee negates the amount.
balance_forward.amount == -1 * balance_backward.amount

What is the best way to implement this?


Solution

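  • I came up with the following solution; feel free to suggest alternatives. Each pair is stored only once, with the lower-id player as `payer`. The model and a custom QuerySet swap `payer`/`payee` and negate `amount` whenever the caller's orientation differs from the stored one. (Here the user model is called `Player`.)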

    class SymmetricPlayersQuerySet (models.query.QuerySet):
        """QuerySet that presents (payer, payee, amount) rows from either side.

        Rows are assumed to be stored with the lower primary key as ``payer``
        (``Balance.save`` writes them that way). Lookups phrased in the other
        orientation are translated, and the returned objects are flipped in
        memory (payer/payee exchanged, amount negated) to match the
        orientation the caller asked for.
        """

        @staticmethod
        def _pk_of(value):
            """Return ``value.pk`` for a model instance, else *value* itself.

            This lets callers pass either instances or raw keys
            (``payer=3``).
            """
            return getattr(value, "pk", value)

        def do_swap(self, obj):
            """Flip *obj* in memory: exchange payer/payee and negate amount."""
            obj.payer, obj.payee = obj.payee, obj.payer
            obj.amount *= -1

        def get(self, *args, **kwargs):
            """Fetch one row; ``payer`` and ``payee`` may be given in either order.

            When both are supplied and payer's key is the larger one, the
            lookup runs in stored orientation and the result is flipped back
            before being returned. Positional Q objects are passed through.
            """
            swap = (
                "payer" in kwargs
                and "payee" in kwargs
                and self._pk_of(kwargs["payer"]) > self._pk_of(kwargs["payee"])
            )
            if swap:
                kwargs["payer"], kwargs["payee"] = kwargs["payee"], kwargs["payer"]

            obj = super().get(*args, **kwargs)

            if swap:
                self.do_swap(obj)
            return obj

        def filter(self, *args, **kwargs):
            """Filter rows; a lone ``payer`` or ``payee`` matches either side.

            With exactly one of ``payer``/``payee`` given, every row involving
            that party is returned. Rows where the party is stored on the other
            side are flipped in memory, so the party always appears under the
            requested name. All other positional Q objects and keyword lookups
            are still applied, against the stored orientation.

            NOTE: in that case the queryset is evaluated here, and the flipped
            objects live only in its result cache. Chaining further QuerySet
            methods re-runs the query and yields rows in stored orientation.
            """
            has_payer = "payer" in kwargs
            has_payee = "payee" in kwargs
            if has_payer == has_payee:
                # Neither or both sides given: nothing to translate.
                return super().filter(*args, **kwargs)

            key, other = ("payer", "payee") if has_payer else ("payee", "payer")
            party = kwargs.pop(key)
            party_pk = self._pk_of(party)

            # Keep the caller's remaining lookups instead of silently dropping them.
            queryset = super().filter(
                models.Q(payer=party) | models.Q(payee=party), *args, **kwargs
            )
            for obj in queryset:
                # Compare raw FK columns (payer_id / payee_id) so deciding the
                # orientation does not fetch the related object for every row.
                stored_as_other = getattr(obj, other + "_id") == party_pk
                stored_as_key = getattr(obj, key + "_id") == party_pk
                # A self-pair row already has the party under `key`; never flip it.
                if stored_as_other and not stored_as_key:
                    self.do_swap(obj)
            return queryset
    
    
    class BalanceManager (models.Manager.from_queryset(SymmetricPlayersQuerySet)):
        """Manager used as ``Balance.objects``.

        It exposes the ``SymmetricPlayersQuerySet`` methods (``get``/``filter``
        accepting payer and payee in either orientation) on the manager.
        """
    
    
    class Balance (models.Model):
        """Amount between two players, stored once per unordered pair.

        ``save`` writes the row with the lower-key player as ``payer``
        (exchanging payer/payee and negating ``amount`` if needed), then
        restores the caller's orientation. As a result,
        ``Balance(a, b).amount == -Balance(b, a).amount``.

        NOTE(review): nothing visible here enforces a single row per pair.
        Consider a unique constraint on (payer, payee).
        """

        objects = BalanceManager()

        payer = models.ForeignKey(
            Player,
            on_delete=models.CASCADE,
            related_name='balance_payer',
        )
        payee = models.ForeignKey(
            Player,
            on_delete=models.CASCADE,
            related_name='balance_payee',
        )
        amount = models.DecimalField(decimal_places=2, max_digits=1000, default=0)

        def do_swap(self):
            """Flip this instance in memory: exchange payer/payee, negate amount."""
            self.payer, self.payee = self.payee, self.payer
            self.amount *= -1

        def _is_reversed(self):
            """Return True if this instance is not in the stored orientation.

            That is the case when payer's primary key is larger than payee's.
            """
            return self.payer.pk > self.payee.pk

        def save(self, *args, **kwargs):
            """Persist in stored orientation, then restore the caller's view.

            The restore runs in ``finally`` so that a failed save (e.g. an
            IntegrityError) does not leave the instance flipped.
            """
            swap = self._is_reversed()
            if swap:
                self.do_swap()
            try:
                return super().save(*args, **kwargs)
            finally:
                if swap:
                    self.do_swap()

        def refresh_from_db(self, *args, **kwargs):
            """Reload from the database, keeping the caller's orientation."""
            swap = self._is_reversed()

            super().refresh_from_db(*args, **kwargs)

            # A full reload returns the stored orientation; flip it back.
            # NOTE(review): on a partial reload (e.g. fields=['amount']) only
            # some fields come back in stored orientation, so this flip leaves
            # payer/payee and amount mutually inconsistent.
            if swap:
                self.do_swap()