Search code examples
join select sqlalchemy many-to-many

How to avoid summing duplicates in SQLAlchemy with many-to-many relationships?


I have 3 SQLAlchemy tables: TagGroup, Tag and Video. TagGroup/Tag and Tag/Video both have bi-directional many-to-many relationships, each specified with an association table. I want to construct a query that sums the Video.viewCount field over all videos reachable from a given TagGroup.id. However, if the same video is linked to 2 separate Tags that belong to the same TagGroup, its Video.viewCount gets summed twice. I would appreciate any help avoiding this.

E.g.

                  Video1
                /
           Tag1
         /      \
 TagGroup         Video2 (will get summed twice)
         \      /
           Tag2
                \
                  Video3

NB: I have removed all unrelated fields for simplicity.

Tag Group Model:

class TagGroup(Base):
    """A group of tags, linked many-to-many to ``Tag`` through an association table."""

    __tablename__ = "tag_groups"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Tags belonging to this group; the reverse side is ``Tag.groups``.
    tags = relationship(
        "Tag",
        back_populates="groups",
        secondary=tags_and_groups_association_table,
    )

Tag Group - Tag Association Table:

# Association table backing the many-to-many link between tags and tag groups.
tags_and_groups_association_table = Table(
    "tags_and_groups_association_table",
    Base.metadata,
    Column("tags_id", ForeignKey("tags.id"), primary_key=True),
    Column("tag_groups_id", ForeignKey("tag_groups.id"), primary_key=True),
    # Composite key over both columns, so each (tag, group) pair is stored once.
    # NOTE(review): both columns are already flagged primary_key=True, so this
    # explicit constraint looks redundant -- confirm before removing.
    PrimaryKeyConstraint('tags_id', 'tag_groups_id')  # to avoid duplicates.
)

Tag Model:

class Tag(Base):
    """A tag; linked many-to-many to both ``Video`` and ``TagGroup``."""

    __tablename__ = "tags"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Videos carrying this tag; the reverse side is ``Video.tags``.
    in_videos = relationship(
        "Video",
        back_populates="tags",
        secondary=video_tags,
        lazy="select",
    )

    # Groups this tag belongs to; the reverse side is ``TagGroup.tags``.
    groups = relationship(
        "TagGroup",
        back_populates="tags",
        secondary=tags_and_groups_association_table,
    )

Tag - Video Association Table:

# Association table backing the many-to-many link between tags and videos.
# Both columns are primary-key columns, so each (tag, video) pair is stored once.
video_tags = Table(
    'video_tags', Base.metadata,
    Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True),
    Column('video_id', Integer, ForeignKey('videos.id'), primary_key=True)
)

Video Model:

class Video(Base):
    """A video with a view counter; linked many-to-many to ``Tag``."""

    __tablename__ = "videos"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

    # Number of views; mandatory, starts at zero.
    viewCount: Mapped[int] = mapped_column(BigInteger, default=0, nullable=False)

    # Tags attached to this video; the reverse side is ``Tag.in_videos``.
    tags = relationship(
        "Tag",
        back_populates="in_videos",
        secondary=video_tags,
        lazy="select",
    )

I have tried various combinations with distinct or group_by video.id, but nothing seems to work.

# Attempted query: per-TagGroup total of view counts plus a distinct video count.
# The join chain produces one row per (group, tag, video) path, so a video that
# is reached through two tags of the same group contributes to SUM twice.
select(
    TagGroup.id,
    func.SUM(Video.viewCount),
    func.COUNT(distinct(Video.id)),  # count works perfectly well with distinct() by unique id.
).select_from(
    TagGroup.id  # NOTE(review): probably meant the TagGroup entity rather than its id column -- confirm
).where(
    TagGroup.id.in_(ids) # list of TagGroup ids
).join(
    tags_and_groups_association_table,
    TagGroup.id == tags_and_groups_association_table.c.tag_groups_id
).join(
    Tag,
    tags_and_groups_association_table.c.tags_id == Tag.id
).join(
    video_tags_model,  # NOTE(review): presumably the table declared as video_tags earlier -- confirm
    Tag.id == video_tags_model.c.tag_id,
).join(
    Video,
    video_tags_model.c.video_id == Video.id
).group_by(
    TagGroup.id,
)

Solution

  • It seems like it works, but try on more test cases.

    If you confirm that it works well, I will add more comments to my answer a bit later.

    from sqlalchemy import (
        Integer, ForeignKey, Column, PrimaryKeyConstraint, Table, BigInteger, create_engine, select, func
    )
    from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase, sessionmaker
    
    
    class Base(DeclarativeBase):
        """Declarative base; its metadata collects the tables and models of this example."""
    
    
    # Association table backing the many-to-many link between tags and tag groups.
    tags_and_groups_association_table = Table(
        "tags_and_groups_association_table",
        Base.metadata,
        Column("tags_id", ForeignKey("tags.id"), primary_key=True),
        Column("tag_groups_id", ForeignKey("tag_groups.id"), primary_key=True),
        # Composite key over both columns, so each (tag, group) pair is stored once.
        # NOTE(review): both columns are already flagged primary_key=True, so this
        # explicit constraint looks redundant -- confirm before removing.
        PrimaryKeyConstraint('tags_id', 'tag_groups_id')  # to avoid duplicates.
    )
    
    
    class TagGroup(Base):
        """A group of tags, linked many-to-many to ``Tag`` through an association table."""

        __tablename__ = "tag_groups"

        # Auto-incrementing integer primary key.
        id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

        # Tags belonging to this group; the reverse side is ``Tag.groups``.
        tags = relationship(
            "Tag",
            back_populates="groups",
            secondary=tags_and_groups_association_table,
        )
    
    
    
    # Association table backing the many-to-many link between tags and videos.
    # Both columns are primary-key columns, so each (tag, video) pair is stored once.
    video_tags = Table(
        'video_tags', Base.metadata,
        Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True),
        Column('video_id', Integer, ForeignKey('videos.id'), primary_key=True)
    )
    
    
    class Tag(Base):
        """A tag; linked many-to-many to both ``Video`` and ``TagGroup``."""

        __tablename__ = "tags"

        # Auto-incrementing integer primary key.
        id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

        # Videos carrying this tag; the reverse side is ``Video.tags``.
        in_videos = relationship(
            "Video",
            back_populates="tags",
            secondary=video_tags,
            lazy="select",
        )

        # Groups this tag belongs to; the reverse side is ``TagGroup.tags``.
        groups = relationship(
            "TagGroup",
            back_populates="tags",
            secondary=tags_and_groups_association_table,
        )
    
    
    class Video(Base):
        """A video with a view counter; linked many-to-many to ``Tag``."""

        __tablename__ = "videos"

        # Auto-incrementing integer primary key.
        id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)

        # Number of views; mandatory, starts at zero.
        viewCount: Mapped[int] = mapped_column(BigInteger, default=0, nullable=False)

        # Tags attached to this video; the reverse side is ``Tag.in_videos``.
        tags = relationship(
            "Tag",
            back_populates="in_videos",
            secondary=video_tags,
            lazy="select",
        )
    
    
    
    # In-memory SQLite database, so the example is self-contained.
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session_maker = sessionmaker(bind=engine)


    with session_maker() as session:
        # Fixture: two groups holding two tags each.
        t_gr_1 = TagGroup()
        t_gr_2 = TagGroup()

        t_1 = Tag()
        t_2 = Tag()
        t_3 = Tag()
        t_4 = Tag()

        t_gr_1.tags.append(t_1)
        t_gr_1.tags.append(t_2)

        t_gr_2.tags.append(t_3)
        t_gr_2.tags.append(t_4)

        # v_1 carries two tags of group 1 and v_3 carries two tags of group 2,
        # so each is reachable twice from a single group -- the duplicate case.
        # v_2 carries one tag from each group and must count once in both.
        v_1 = Video(viewCount=111, tags=[t_1, t_2])
        v_2 = Video(viewCount=222, tags=[t_1, t_3])
        v_3 = Video(viewCount=444, tags=[t_3, t_4])
        v_4 = Video(viewCount=444, tags=[t_3])

        # Only the groups are added explicitly; the tags and videos are expected
        # to be persisted along with them through the relationship cascades.
        session.add(t_gr_1)
        session.add(t_gr_2)
        session.commit()

        # Read the generated keys back instead of assuming they are 1 and 2.
        ids = [t_gr_1.id, t_gr_2.id]

        # Step 1: collapse the (group, tag, video) join paths to one row per
        # (group, video) pair, so a video reached through several tags of the
        # same group appears only once for that group.
        subq = (
            select(
                Video.viewCount.label("viewCount"),
                TagGroup.id.label("taggroup_id"),
            )
            .select_from(TagGroup)
            .join(tags_and_groups_association_table)
            .join(Tag)
            .join(video_tags)
            .join(Video)
            # Restrict to the requested groups before aggregating.
            .where(TagGroup.id.in_(ids))
            # viewCount is determined by Video.id, so listing it does not change
            # the grouping; it keeps the statement valid on databases that
            # reject selecting a column that is not grouped.
            .group_by(TagGroup.id, Video.id, Video.viewCount)
            .subquery()
        )

        # Step 2: sum the de-duplicated rows per group; the explicit ordering
        # makes the printed output deterministic.
        st = (
            select(subq.c.taggroup_id, func.SUM(subq.c.viewCount))
            .select_from(subq)
            .group_by(subq.c.taggroup_id)
            .order_by(subq.c.taggroup_id)
        )

        for row in session.execute(st):
            print(row)
    
    

    Output:

    (1, 333)
    (2, 1110)