Search code examples
python, sqlalchemy, many-to-many

SQLAlchemy Many-to-Many Relationship: UNIQUE constraint failed


So, I have a many-to-many SQLAlchemy relationship defined like so:

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint, Table, create_engine
from sqlalchemy.orm import relationship, registry


# Declarative base for the model classes, plus a separate registry that is
# used only to map the plain bridge class imperatively.
Base = declarative_base()
mapper_registry = registry()


class BridgeCategory:
    """Plain class mapped imperatively onto the ``bridge_category`` table."""


# Association table linking videos to categories; each (video, category)
# pair forms the composite primary key.
bridge_category = Table(
    "bridge_category",
    Base.metadata,
    Column(
        "video_id",
        ForeignKey("video.id"),
        primary_key=True,
    ),
    Column(
        "category_id",
        ForeignKey("category.id"),
        primary_key=True,
    ),
    UniqueConstraint("video_id", "category_id"),
)
mapper_registry.map_imperatively(BridgeCategory, bridge_category)


class Video(Base):
    """Video row; linked to categories through ``bridge_category``."""

    __tablename__ = "video"

    id = Column(Integer, primary_key=True)
    title = Column(String)

    # Many-to-many: mirrored by ``Category.videos``.
    categories = relationship(
        "Category",
        secondary=bridge_category,
        back_populates="videos",
    )


class Category(Base):
    """Category row; ``text`` is unique across the whole table."""

    __tablename__ = "category"

    id = Column(Integer, primary_key=True)
    text = Column(String, unique=True)

    # Many-to-many: mirrored by ``Video.categories``.
    videos = relationship(
        "Video",
        secondary=bridge_category,
        back_populates="categories",
    )


engine = create_engine("sqlite:///:memory:", echo=True)
Base.metadata.create_all(engine)

Session = sessionmaker(bind=engine)

# Reproduction: every video builds its own Category(text="red"), so the
# commit tries to insert "red" more than once and trips the UNIQUE
# constraint on category.text.
with Session() as s:
    v1 = Video(
        title="A",
        categories=[Category(text="blue"), Category(text="red")],
    )
    v2 = Video(
        title="B",
        categories=[Category(text="green"), Category(text="red")],
    )
    v3 = Video(
        title="C",
        categories=[Category(text="grey"), Category(text="red")],
    )
    videos = [v1, v2, v3]

    for video in videos:
        s.add(video)
    s.commit()

Of course, because of the unique constraint on Category.text, we get the following error.

sqlalchemy.exc.IntegrityError: (sqlite3.IntegrityError) UNIQUE constraint failed: category.text
[SQL: INSERT INTO category (text) VALUES (?) RETURNING id]
[parameters: ('red',)]

I am wondering what the best way of dealing with this is. With my program, I get a lot of video objects, each with a list of unique Category objects. The text collisions happen across all these video objects.

I could loop through all videos, and all categories, forming a Category set, but that's kinda lame. I'd also have to do that with the 12+ other many-to-many relationships my Video object has, and that seems really inefficient.

Is there something like an "insert ignore" flag I can set for this? I haven't been able to find anything online concerning this situation.


Solution

  • With a lot of help from the maintainer of SQLAlchemy, I came up with a generic implementation that requires hardly any configuration or repeated steps, and that works for a single SA model object containing multiple many-to-many relationships.

    from sqlalchemy import Column
    from sqlalchemy import create_engine
    from sqlalchemy import event
    from sqlalchemy import ForeignKey
    from sqlalchemy import inspect
    from sqlalchemy import Integer
    from sqlalchemy import select
    from sqlalchemy import String
    from sqlalchemy import Table
    from sqlalchemy import UniqueConstraint
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.orm import registry
    from sqlalchemy.orm import relationship
    from sqlalchemy.orm import RelationshipDirection
    from sqlalchemy.orm import Session
    from sqlalchemy.orm import sessionmaker
    
    
    # Declarative base for the model classes, plus a separate registry that
    # is used only to map the plain bridge classes imperatively.
    Base = declarative_base()
    mapper_registry = registry()


    class BridgeCategory:
        """Plain class mapped imperatively onto the ``bridge_category`` table."""


    # Association table linking videos to categories; each (video, category)
    # pair forms the composite primary key.
    bridge_category = Table(
        "bridge_category",
        Base.metadata,
        Column(
            "video_id",
            ForeignKey("video.id"),
            primary_key=True,
        ),
        Column(
            "category_id",
            ForeignKey("category.id"),
            primary_key=True,
        ),
        UniqueConstraint("video_id", "category_id"),
    )
    mapper_registry.map_imperatively(BridgeCategory, bridge_category)
    
    
    class BridgeFormat:
        """Plain class mapped imperatively onto the ``bridge_format`` table."""


    # Association table linking videos to formats; each (video, format)
    # pair forms the composite primary key.
    bridge_format = Table(
        "bridge_format",
        Base.metadata,
        Column(
            "video_id",
            ForeignKey("video.id"),
            primary_key=True,
        ),
        Column(
            "format_id",
            ForeignKey("format.id"),
            primary_key=True,
        ),
        UniqueConstraint("video_id", "format_id"),
    )
    mapper_registry.map_imperatively(BridgeFormat, bridge_format)
    
    
    class Video(Base):
        """Video row; linked to categories and formats through bridge tables."""

        __tablename__ = "video"

        id = Column(Integer, primary_key=True)
        title = Column(String)

        # Many-to-many: mirrored by ``Category.videos``.
        categories = relationship(
            "Category",
            secondary=bridge_category,
            back_populates="videos",
        )
        # Many-to-many: mirrored by ``Format.videos``.
        formats = relationship(
            "Format",
            secondary=bridge_format,
            back_populates="videos",
        )
    
    class Category(Base):
        """Category row; ``text`` is unique across the whole table."""

        __tablename__ = "category"

        id = Column(Integer, primary_key=True)
        text = Column(String, unique=True)

        # Many-to-many: mirrored by ``Video.categories``.
        videos = relationship(
            "Video",
            secondary=bridge_category,
            back_populates="categories",
        )
    
    class Format(Base):
        """Format row; ``text`` is unique across the whole table."""

        __tablename__ = "format"

        id = Column(Integer, primary_key=True, index=True)
        text = Column(String, unique=True)

        # Many-to-many: mirrored by ``Video.formats``.
        videos = relationship(
            "Video",
            back_populates="formats",
            secondary=bridge_format,
        )
    
    
    def unique_robs(session_or_factory, main_obj, rob_unique_col):
        """Uniquify the many-to-many related objects ("robs") of ``main_obj``.

        Registers a ``before_attach`` listener on ``session_or_factory``. Each
        time an instance of ``main_obj`` is about to be attached to a session,
        every list-valued many-to-many relationship on it is replaced by a list
        in which each related object is the canonical one for its
        ``rob_unique_col`` value: either a row already stored in the database
        or the first object carrying that value seen by this session.

        Args:
            session_or_factory: Event target the listener is attached to (the
                caller in this file passes a ``sessionmaker``).
            main_obj: Mapped class whose many-to-many collections are
                uniquified.
            rob_unique_col: Name of the attribute holding the unique value; the
                same name is used for every related class.

        Returns:
            None. The function is called for its side effect of registering
            the event listener.
        """

        def _unique_robs(session, robs, rob_name):
            """Return ``robs`` with duplicates swapped for canonical objects.

            ``rob_name`` is the relationship attribute name; it keys the
            per-session cache kept in ``session.info``.
            """
            if not robs:
                return robs

            rob_type = type(robs[0])
            unique_col = getattr(rob_type, rob_unique_col)

            # Querying must not flush the half-built object graph, otherwise
            # the duplicates would hit the database before being replaced.
            with session.no_autoflush:
                # Per-session cache: {unique value: canonical related object}.
                known_robs = session.info.setdefault(rob_name, {})

                values = [getattr(rob, rob_unique_col) for rob in robs]
                missing = [val for val in values if val not in known_robs]

                # Only hit the database for values the cache cannot answer;
                # an empty IN () lookup would be a wasted round trip.
                if missing:
                    stmt = select(rob_type).where(unique_col.in_(missing))
                    for stored in session.scalars(stmt):
                        known_robs[getattr(stored, rob_unique_col)] = stored

                # setdefault returns the canonical object when one is known;
                # otherwise it registers this object as canonical and
                # returns it unchanged.
                return [
                    known_robs.setdefault(val, rob)
                    for rob, val in zip(robs, values)
                ]

        @event.listens_for(session_or_factory, "before_attach", retval=True)
        def before_attach(session, obj):
            """Uniquifies all `main_obj` many-to-many relationships."""
            if not isinstance(obj, main_obj):
                return

            for rel in inspect(obj).mapper.relationships:
                if rel.direction is not RelationshipDirection.MANYTOMANY:
                    continue

                rob_name = rel.class_attribute.key
                robs = getattr(obj, rob_name, None)
                # Scalar or otherwise non-list values are left untouched.
                if isinstance(robs, list):
                    setattr(obj, rob_name, _unique_robs(session, robs, rob_name))
    
    
    if __name__ == "__main__":
        engine = create_engine("sqlite:///test.db", echo=True)
        Base.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)

        # Install the uniquifying listener before any Video is attached.
        unique_robs(Session, Video, "text")

        v1 = Video(
            title="A",
            categories=[Category(text="blue"), Category(text="red")],
        )
        v2 = Video(
            title="B",
            categories=[Category(text="green"), Category(text="red")],
            formats=[Format(text="h264")],
        )
        v3 = Video(
            title="C",
            categories=[Category(text="grey"), Category(text="red")],
            formats=[Format(text="h264"), Format(text="vp9")],
        )
        videos = [v1, v2, v3]

        with Session() as s:
            for video in videos:
                s.add(video)
            s.commit()