
SQLAlchemy-Utils Aggregated Attributes: How to apply filter before aggregating to create an aggregated field?


from sqlalchemy_utils import aggregated


class Thread(Base):
    __tablename__ = 'thread'
    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.Unicode(255))

    @aggregated('comments', sa.Column(sa.Integer))
    def comment_count(self):
        return sa.func.count('1')

    comments = sa.orm.relationship(
        'Comment',
        backref='thread'
    )


class Comment(Base):
    __tablename__ = 'comment'
    id = sa.Column(sa.Integer, primary_key=True)
    content = sa.Column(sa.UnicodeText)
    thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
    active = sa.Column(sa.Boolean)

I want the comment_count field to be the count of only the active comments instead of all comments. Is this possible?

The example is based on the documentation: https://sqlalchemy-utils.readthedocs.io/en/latest/aggregates.html


Solution

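  • Keep sa.func.count, but count a CASE expression instead of a constant. The snippet in the question also assumes sa (import sqlalchemy as sa) and a declarative Base created with declarative_base(); the complete example below defines both.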

    from sqlalchemy_utils import aggregated
    import sqlalchemy as sa
    # SQLAlchemy 1.4+; before 1.4 this lived in sqlalchemy.ext.declarative
    from sqlalchemy.orm import declarative_base
    
    Base = declarative_base()
    
    class Thread(Base):
        __tablename__ = 'thread'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    
        @aggregated('comments', sa.Column(sa.Integer))
        def comment_count(self):
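            # CASE yields 1 for active comments and NULL otherwise; COUNT skips NULLs.
            # SQLAlchemy < 1.4 expects the list form: sa.case([(Comment.active == True, 1)])
            return sa.func.count(sa.case((Comment.active == True, 1)))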
    
        comments = sa.orm.relationship(
            'Comment',
            backref='thread'
        )
    
    class Comment(Base):
        __tablename__ = 'comment'
        id = sa.Column(sa.Integer, primary_key=True)
        content = sa.Column(sa.UnicodeText)
        thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
        active = sa.Column(sa.Boolean)
    
    • The sa.case(...) expression renders as CASE WHEN comment.active = true THEN 1 END. It returns 1 for an active comment and, because no else_ value is given, NULL for every other comment.

    • COUNT ignores NULL values, so sa.func.count counts only the rows where the CASE expression returned 1, i.e. the active comments. comment_count therefore holds the number of active comments in each thread.
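To check the COUNT-skips-NULL behaviour in isolation, here is a small sketch that runs comparable SQL by hand with Python's built-in sqlite3 module instead of SQLAlchemy, so it needs no extra packages. The sample rows are made up for illustration, and the correlated subqueries only approximate the shape of the aggregate; they are not the exact SQL that SQLAlchemy-Utils emits.

```python
import sqlite3

# In-memory database mirroring the thread/comment tables from the question
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE thread (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE comment (
        id INTEGER PRIMARY KEY,
        content TEXT,
        thread_id INTEGER REFERENCES thread(id),
        active BOOLEAN
    );
    INSERT INTO thread (id, name) VALUES (1, 'a'), (2, 'b');
    -- thread 1: two active comments, one inactive; thread 2: one inactive comment
    INSERT INTO comment (thread_id, active) VALUES (1, 1), (1, 0), (1, 1), (2, 0);
""")

query = """
    SELECT t.id,
           -- the original aggregate: counts every comment
           (SELECT count('1') FROM comment c
             WHERE c.thread_id = t.id) AS all_comments,
           -- CASE gives NULL for inactive rows, and COUNT skips NULLs
           (SELECT count(CASE WHEN c.active = 1 THEN 1 END) FROM comment c
             WHERE c.thread_id = t.id) AS active_count,
           -- SUM over the same CASE, for comparison
           (SELECT sum(CASE WHEN c.active = 1 THEN 1 END) FROM comment c
             WHERE c.thread_id = t.id) AS active_sum
    FROM thread t
    ORDER BY t.id
"""
rows = conn.execute(query).fetchall()
print(rows)  # [(1, 3, 2, 2), (2, 1, 0, None)]
```

The active_count column counts only active comments and gives thread 2 a value of 0. SUM over the same CASE returns NULL (None) for a thread with no active comments, which is one reason COUNT is the better fit for an integer counter column.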