Tags: django, django-models, orm, django-mptt

Annotate Total Count of Descendants of MPTT Model


Question

Given the models below, I want a queryset of all pages, each annotated with the total number of comments in the page's thread, counting every comment in the thread's comment trees (root comments and all nested replies).

I am using django-mptt to store the comment tree.

I can get this in Python using comment.get_descendant_count(), but that is very inefficient when querying all pages.

Models

class CommentThread(models.Model):
    ...


class Page(models.Model):
    ...
    thread = models.ForeignKey("CommentThread", ...)   



class Comment(MPTTModel):
    body = models.TextField()
    author = models.ForeignKey("User", ...)

    # Validation ensures a comment has
    # a thread or parent comment but not both
    thread = models.ForeignKey("CommentThread", related_name="comments", ...)
    parent = TreeForeignKey(
                'self',
                on_delete=models.CASCADE,
                null=True,
                blank=True,
                related_name='children'
    )

    class MPTTMeta:
        level_attr = 'mptt_level'
        order_insertion_by=['name']

This model allows me to add multiple "root comments" to a page, and also to nest comments under each comment as replies, recursively.


# Model Use Examples

thread = CommentThread()
page = Page(thread=thread)

# add page level root comments
comment1 = Comment(thread=thread, ...)
comment2 = Comment(thread=thread, ...)
comment3 = Comment(thread=thread, ...)

# add Comment Replies to comment #1
comment_reply1 = Comment(parent=comment1, ...)
comment_reply2 = Comment(parent=comment1, ...)
comment_reply3 = Comment(parent=comment1, ...)

Current approach - in python

Works but very inefficient:

page = Page.objects.first()
# each root comment counts itself plus all of its nested replies
total_comments = sum(c.get_descendant_count() + 1 for c in page.thread.comments.all())

What I have tried

I am not sure how to achieve this with querysets and annotations. I know each MPTT model instance also gets a tree_id, so I am guessing I would need to build a more complex subquery.

To get the number of root comments only (not including nested), I could do it like this:

pages = Page.objects.all().annotate(num_comments=models.Count("thread__comments"))
num_root_comments = pages[0].num_comments

Once again, the goal is to get all comments, including nested:

# Non-working code: a pseudo-queryset sketch of what I am trying to do

all_page_comments = Comment.objects.filter(tree_id__in=(Page.thread__comments__tree_id))
Page.objects.all().annotate(num_comments=Count(Subquery(all_page_comments)))

Thanks in advance for any help provided.

Solution

Got a working solution thanks to @andrey's answer below. Not sure it's optimal but seems to return the correct values in a single query.

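from django.db import models
# Django 2.2+; on older versions use the custom Floor class from the answer below
from django.db.models.functions import Floor

threads = CommentThread.objects.filter(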
        id=models.OuterRef("thread")
    ).annotate(
        comment_count=models.Sum(
            Floor((models.F("comments__rght") - models.F("comments__lft") - 1) / 2)
        )
    )

qs_pages_with_comment_count = (
    Page.objects
    .annotate(
        comment_count=models.Subquery(
            threads.values("comment_count")[:1], output_field=models.IntegerField()
        )
    )
    # The subquery counts the descendants of each root comment
    # but not the root comments themselves, so we add the
    # number of root comments per thread on top
    .annotate(
        comment_count=models.F("comment_count")
        + models.Count("thread__comments", distinct=True)
    )
)
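
The arithmetic behind the two annotations can be checked in plain Python, without Django. This is a minimal sketch under stated assumptions: the `(lft, rght)` pairs and the names `root_bounds_by_thread`, `total_comments` and `total_comments_single_sum` are made up for illustration, and each root comment is assumed to be the root of its own tree, so its `lft` is 1.

```python
# Hypothetical (lft, rght) pairs for the root comments of two threads.
# Only root comments carry a thread, so these are the rows the join sees.
root_bounds_by_thread = {
    "thread_a": [(1, 8), (1, 2), (1, 4)],  # roots with 3, 0 and 1 descendants
    "thread_b": [(1, 2)],                  # a single comment with no replies
}


def total_comments(root_bounds):
    # Mirrors the two-step query: sum of descendants, plus the root count.
    descendants = sum((rght - lft - 1) // 2 for lft, rght in root_bounds)
    return descendants + len(root_bounds)


def total_comments_single_sum(root_bounds):
    # (rght - lft + 1) / 2 counts each root together with its descendants.
    return sum((rght - lft + 1) // 2 for lft, rght in root_bounds)


totals = {
    thread: total_comments(bounds)
    for thread, bounds in root_bounds_by_thread.items()
}
print(totals)
```

Both functions agree on this data, which suggests the two annotations could be folded into a single Sum over (rght - lft + 1) / 2. That variant has not been run against the models above, so treat it as an idea to test rather than a drop-in replacement.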

Solution

    queryset.annotate(
        descendants_count=Floor((F('rght') - F('lft') - 1) / 2)
    ).values(
        'descendants_count'
    ).aggregate(
        total_count=Count('descendants_count')
    )
    

    Let me explain

    First, get_descendant_count only does arithmetic on values already stored on the row, so the same calculation can be done in a queryset.

    def get_descendant_count(self):
        """
        Returns the number of descendants this model instance has.
        """
        if self._mpttfield('right') is None:
            # node not saved yet
            return 0
        else:
            return (self._mpttfield('right') - self._mpttfield('left') - 1) // 2
    

    This is the current method on MPTT models. In a queryset we can be sure that every instance is already saved, so we can skip that check.
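
To see why the formula works, here is a minimal sketch in plain Python, independent of Django. The helper names `number_tree` and `descendant_count` and the sample tree are assumptions made for illustration and are not part of django-mptt. The sketch numbers a small tree depth-first, the way a nested-set layout does, and applies the same arithmetic as the method above.

```python
def number_tree(children, root):
    """Assign nested-set (lft, rght) values with a depth-first walk."""
    bounds = {}
    counter = 1

    def visit(node):
        nonlocal counter
        lft = counter          # number handed out on the way down
        counter += 1
        for child in children.get(node, []):
            visit(child)
        bounds[node] = (lft, counter)  # number handed out on the way up
        counter += 1

    visit(root)
    return bounds


def descendant_count(lft, rght):
    # Same arithmetic as get_descendant_count(): every descendant
    # uses up two numbers strictly between lft and rght.
    return (rght - lft - 1) // 2


# Hypothetical comment tree: a root with two replies, one of which has a reply.
children = {"root": ["reply1", "reply2"], "reply1": ["nested"]}
bounds = number_tree(children, "root")

for node, (lft, rght) in sorted(bounds.items(), key=lambda item: item[1]):
    print(node, lft, rght, descendant_count(lft, rght))
```

Because each node consumes exactly two numbers, the gap between a node's lft and rght is always twice its number of descendants plus one, so the division never leaves a remainder.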

    The next step is to translate the arithmetic into database expressions. Django has shipped a Floor function in django.db.models.functions since version 2.2, but the same thing can be defined by hand on older versions (I do this on 1.7):

    from django.db.models.lookups import Transform
    
    class Floor(Transform):
         function = 'FLOOR'
         lookup_name = 'floor'
    

    If you want, you can refactor this to look up the field names the way self._mpttfield('right') does instead of hardcoding rght and lft, and expose it as a Manager method.

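    Let's test. I have a top element with descendants. Note that Count here counts the rows of the queryset, one per descendant, which is why it matches get_descendant_count() on the parent; to total the per-row descendants_count values, use Sum instead.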

    In [1]: m = MenuItem.objects.get(id=settings.TOP_MENU_ID)
    
    In [2]: m.get_descendant_count()
    Out[2]: 226
    
    In [3]: n = m.get_descendants()
    
    In [4]: n.annotate(descendants_count=Floor((F('rght') - F('lft') - 1) / 2)).values('descendants_count').aggregate(total_count=Count('descendants_count'))
    Out[4]: {'total_count': 226}