I want to create a model that is symmetric in two of its fields. Let's call the model Balance:
class Balance(models.Model):
    """Sketch of the desired model: an amount recorded between two users."""

    # The "..." placeholders stand for field arguments omitted in this sketch.
    payer = models.ForeignKey(auth.User, ...)
    payee = models.ForeignKey(auth.User, ...)
    amount = models.DecimalField(...)
It should have the following property (looking up the same pair of users in the opposite order yields the negated amount):
# Look up the same pair of users once in each direction.
balance_forward = Balance.objects.get(payer=USER_1, payee=USER_2)
balance_backward = Balance.objects.get(payer=USER_2, payee=USER_1)
# Desired invariant: the two lookups report amounts of opposite sign.
balance_forward.amount == -1 * balance_backward.amount
What is the best way to implement this?
So, I came up with the following solution. Feel free to suggest other solutions.
class SymmetricPlayersQuerySet(models.query.QuerySet):
    """QuerySet that returns rows oriented the way the caller asked for them.

    A row can be read in either direction: exchanging ``payer`` and ``payee``
    and negating ``amount`` describes the same balance.  ``get`` and
    ``filter`` rewrite lookups on those two fields and re-orient the fetched
    objects in memory so they match the direction of the query.
    """

    def do_swap(self, obj):
        """Reverse ``obj`` in place: exchange payer/payee and negate amount."""
        obj.payer, obj.payee = obj.payee, obj.payer
        obj.amount *= -1

    def get(self, *args, **kwargs):
        """Fetch a single object, honouring the requested payer/payee order.

        When both ``payer`` and ``payee`` are given and the payer's ``id`` is
        the larger one, the lookup is issued with the two values exchanged and
        the fetched object is swapped back before being returned.
        NOTE(review): this assumes rows are stored with the lower-id user as
        payer -- confirm against the model's ``save``.

        Positional ``Q`` arguments are forwarded unchanged, as in Django's own
        ``QuerySet.get``.  Exceptions raised by the underlying ``get``
        propagate unchanged.
        """
        swap = False
        if "payer" in kwargs and "payee" in kwargs:
            if kwargs["payer"].id > kwargs["payee"].id:
                swap = True
                kwargs["payer"], kwargs["payee"] = (
                    kwargs["payee"],
                    kwargs["payer"],
                )
        obj = super().get(*args, **kwargs)
        if swap:
            self.do_swap(obj)
        return obj

    def filter(self, *args, **kwargs):
        """Filter rows, matching a one-sided payer/payee lookup on either side.

        If exactly one of ``payer`` / ``payee`` is given, rows where that
        value appears in *either* column are selected, and each object where
        it sits in the other column is swapped so the value ends up in the
        field the caller named.  All remaining positional and keyword lookups
        are applied as usual.  Any other combination is passed straight to the
        default ``filter``.

        NOTE: in the one-sided case the queryset is evaluated immediately and
        the swap is applied to the objects held by the returned queryset.
        Chaining further queryset methods onto the result presumably produces
        a new queryset that re-fetches the rows without the swap.
        """
        if ("payer" in kwargs) != ("payee" in kwargs):
            if "payee" in kwargs:
                key, other = "payee", "payer"
            else:
                key, other = "payer", "payee"
            target = kwargs[key]
            # Keep every other lookup; only the one-sided key is replaced by
            # the either-column condition below.
            remaining = {
                name: value for name, value in kwargs.items() if name != key
            }
            either_side = models.Q(payer=target) | models.Q(payee=target)
            queryset = super().filter(either_side, *args, **remaining)
            for obj in queryset:
                # The row matched through the opposite column: flip it.
                if getattr(obj, other) == target:
                    self.do_swap(obj)
            return queryset
        return super().filter(*args, **kwargs)
class BalanceManager(models.Manager.from_queryset(SymmetricPlayersQuerySet)):
    """Manager built from ``SymmetricPlayersQuerySet``; adds nothing itself."""
class Balance(models.Model):
    """Signed amount between two players, persisted in one fixed direction.

    Before a row is written, the instance is flipped if necessary so that
    the payer's ``id`` is not greater than the payee's; afterwards it is
    flipped back, so the in-memory object keeps the direction the caller
    gave it.
    """

    objects = BalanceManager()

    payer = models.ForeignKey(
        Player,
        on_delete=models.CASCADE,
        related_name='balance_payer',
    )
    payee = models.ForeignKey(
        Player,
        on_delete=models.CASCADE,
        related_name='balance_payee',
    )
    amount = models.DecimalField(decimal_places=2, max_digits=1000, default=0)

    def do_swap(self):
        """Reverse this instance in place: exchange payer/payee, negate amount."""
        self.payer, self.payee = self.payee, self.payer
        self.amount *= -1

    def save(self, *args, **kwargs):
        """Save in the stored direction, then restore the caller's direction.

        Arguments are forwarded to ``Model.save`` and its result is returned.
        The instance is restored to the caller's direction even if saving
        raises, so a failed save does not leave it flipped.
        """
        swap = self.payer.id > self.payee.id
        if swap:
            self.do_swap()
        try:
            return super().save(*args, **kwargs)
        finally:
            if swap:
                self.do_swap()

    def refresh_from_db(self, *args, **kwargs):
        """Reload from the database while keeping the caller's direction.

        The instance is flipped to the stored direction before reloading and
        flipped back afterwards.  This keeps a partial reload (``fields=...``)
        consistent: every field is in the stored direction while the reload
        happens, whichever fields are actually refreshed.  The caller's
        direction is restored even if the reload raises.
        """
        swap = self.payer.id > self.payee.id
        if swap:
            self.do_swap()
        try:
            super().refresh_from_db(*args, **kwargs)
        finally:
            if swap:
                self.do_swap()