
How to limit N results per `group_by` in SQLAlchemy/Postgres?


This is my SQLAlchemy query code:

    medium_contact_id_subq = (
        g.session.query(distinct(func.unnest(FUContact.medium_contact_id_lis)))
        .filter(FUContact._id.in_(contact_id_lis))
        .subquery()
    )
    q = (
        g.session.query(FUMessage)
        .filter(FUMessage.fu_medium_contact_id.in_(medium_contact_id_subq))
        .order_by(desc(FUMessage.timestamp_utc))
    )

I'd like to limit the result to at most N FUMessage rows per medium_contact_id.


As a workaround, this is my current ugly and unoptimized code:

    medium_contact_id_lis = (g.session.query(distinct(func.unnest(FUContact.medium_contact_id_lis))).filter(FUContact._id.in_(contact_id_lis))).all()
    q = None
    for medium_contact_id_tup in medium_contact_id_lis:
        medium_contact_id = medium_contact_id_tup[0]
        if q is None:
            q = (g.session.query(FUMessage)
                 .filter(FUMessage.fu_medium_contact_id == medium_contact_id)
                 .limit(MESSAGE_LIMIT)
                 )
        else:
            subq = (g.session.query(FUMessage)
                 .filter(FUMessage.fu_medium_contact_id == medium_contact_id)
                 .limit(MESSAGE_LIMIT)
                 )
            q = q.union(subq)
    q = q.order_by(desc(FUMessage.timestamp_utc))

Solution

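  • One way to fetch the top N rows per group is to compute a window function such as rank() or row_number() in a subselect, with the required partitioning and ordering, and then filter on that value in the enclosing select. Note that rank() gives tied rows the same number, so it can return more than N rows per group, while row_number() never does. For N = 1 you could instead use the DISTINCT ON ... ORDER BY combination in PostgreSQL.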

    Adapting that to SQLAlchemy is straightforward: a function element's over() method produces a window expression:

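    from sqlalchemy import func
    from sqlalchemy.orm import aliased

    # Distinct medium contact ids belonging to the requested contacts.
    medium_contact_id_subq = g.session.query(
            func.unnest(FUContact.medium_contact_id_lis).distinct()).\
        filter(FUContact._id.in_(contact_id_lis)).\
        subquery()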
    
    # Perform required filtering in the subquery. Choose a suitable ordering,
    # or you'll get indeterminate results. Descending timestamps number the
    # newest message of each contact as 1, which matches the question's
    # desc(timestamp_utc) ordering.
    subq = g.session.query(
            FUMessage,
            func.row_number().over(
                partition_by=FUMessage.fu_medium_contact_id,
                order_by=FUMessage.timestamp_utc.desc()).label('n')).\
        filter(FUMessage.fu_medium_contact_id.in_(medium_contact_id_subq)).\
        subquery()
    
    fumessage_alias = aliased(FUMessage, subq)
    
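    # row_number() counts up from 1, so include rows with a row num
    # less than or equal to limit. The final order_by restores the overall
    # ordering used in the question's original query.
    q = g.session.query(fumessage_alias).\
        filter(subq.c.n <= MESSAGE_LIMIT).\
        order_by(fumessage_alias.timestamp_utc.desc())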
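
To see the shape of the SQL behind this technique without a PostgreSQL server, here is a minimal sketch using Python's built-in sqlite3 module. It is not the poster's schema: the fu_message table, its three columns, the sample rows and MESSAGE_LIMIT = 2 are all made up for illustration, and it assumes the bundled SQLite is version 3.25 or newer (the first release with window functions). The statement is hand-written to mirror the subselect-plus-filter pattern, not output captured from SQLAlchemy.

```python
import sqlite3

MESSAGE_LIMIT = 2  # hypothetical N for this sketch

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE fu_message ("
    "id INTEGER PRIMARY KEY, "
    "fu_medium_contact_id INTEGER, "
    "timestamp_utc INTEGER)"
)

# Made-up data: contact 10 has 4 messages, contact 20 has 1, contact 30 has 3.
rows = [
    (1, 10, 100), (2, 10, 200), (3, 10, 300), (4, 10, 400),
    (5, 20, 150),
    (6, 30, 50), (7, 30, 250), (8, 30, 350),
]
conn.executemany("INSERT INTO fu_message VALUES (?, ?, ?)", rows)

# The inner select numbers each contact's rows, newest first.
# The outer select keeps rows numbered 1..MESSAGE_LIMIT and then
# sorts the combined result by timestamp, newest first.
top_n_sql = """
SELECT id, fu_medium_contact_id, timestamp_utc
FROM (
    SELECT m.*,
           row_number() OVER (
               PARTITION BY fu_medium_contact_id
               ORDER BY timestamp_utc DESC
           ) AS n
    FROM fu_message AS m
) AS numbered
WHERE n <= ?
ORDER BY timestamp_utc DESC
"""

top_messages = conn.execute(top_n_sql, (MESSAGE_LIMIT,)).fetchall()
for row in top_messages:
    print(row)
```

Each contact contributes at most MESSAGE_LIMIT rows, and a contact with fewer messages simply contributes all of them. The filter on n has to live in the outer select because a window function's result cannot be referenced in the WHERE clause of the same select that computes it.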