python, sqlalchemy, many-to-many, relationship

How do I build a SQLAlchemy many-to-many relationship on a single table with extra data in the related association table?


I would like to build a many-to-many relationship between instances of the same class. (An object can be composed of zero or more objects of the same type, and can itself be contained in zero or more other objects of the same type.)

I also would like to add extra data to each record in the resulting association table.

For the many-to-many relationship within the same table, I found this SO answer (https://stackoverflow.com/a/5652169/3006060) and produced the following test code:

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker


SQLALCHEMY_DB_URL: str = "sqlite:///./loadcases.sqlite"


class Base(DeclarativeBase):
    pass


load_case_sub_load_association = Table(
    "load_case_sub_load_association",
    Base.metadata,
    Column("load_case_id", String, ForeignKey("load_cases.id")),
    Column("sub_load_case_id", String, ForeignKey("load_cases.id"))
)


class LoadCase(Base):
    __tablename__ = "load_cases"

    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    title: Mapped[str] = mapped_column(index=True)
    load_cases = relationship('LoadCase',
                              secondary="load_case_sub_load_association",
                              backref='sub_load_cases',
                              primaryjoin=load_case_sub_load_association.c.load_case_id==id,
                              secondaryjoin=load_case_sub_load_association.c.sub_load_case_id==id
                              )


if __name__ == '__main__':
    engine = create_engine(
        SQLALCHEMY_DB_URL,
        connect_args={"check_same_thread": False}  # needed for sqlite, remove when using other dbs
    )
    Base.metadata.create_all(bind=engine)
    Session = sessionmaker(engine)  # keep Session in same scope as engine!

    with Session() as session:
        ulc_t = LoadCase(title='ulc_t')
        ulc_p = LoadCase(title='ulc_p')
        ulc_m = LoadCase(title='ulc_m')
        session.add_all([ulc_t, ulc_p, ulc_m])
        clc = LoadCase(title='clc', load_cases=[ulc_t, ulc_p, ulc_m])
        session.add(clc)
        session.commit()

This generates the following tables:

 +------------+    +---------------------------------+
 | load_cases |    | load_case_sub_load_association  |
 +----+-------+    +--------------+------------------+
 | id | title |    | load_case_id | sub_load_case_id |
 +----+-------+    +--------------+------------------+
 | 1  | ulc_t |    | 4            | 1                |
 | 2  | ulc_p |    | 4            | 2                |
 | 3  | ulc_m |    | 4            | 3                |
 | 4  | clc   |    +--------------+------------------+
 +----+-------+

So far so good.

However, I now want to add an extra column to the association table containing a float value, for example:

+------------------------------------------+
|      load_case_sub_load_association      |
+--------------+------------------+--------+
| load_case_id | sub_load_case_id | factor |
+--------------+------------------+--------+
| 4            | 1                | 1.5    |
| 4            | 2                | 1.0    |
| 4            | 3                | -2.7   |
+--------------+------------------+--------+

I added a column to the association table like this:

load_case_sub_load_association = Table(
    "load_case_sub_load_association",
    Base.metadata,
    Column("load_case_id", String, ForeignKey("load_cases.id")),
    Column("sub_load_case_id", String, ForeignKey("load_cases.id")),
    Column("factor", Float)  # <-- new column; needs "from sqlalchemy import Float"
)

This produces the correct table layout, but the factor is of course empty (NULL).

I don't know how to amend the relationship('LoadCase', ...) to be able to add values to that column ...

At some point I read that I need to use an association object and I found an example (https://stackoverflow.com/a/62378982/3006060), but that actually constructed a many-to-many relationship between different object types, and I was not able to bring the two solutions together ...

How do I build a SQLAlchemy many-to-many relationship on a single table with extra data in the related association table? How would I feed in the extra data when the association is built and how would I retrieve the data again from the DB?

PS: I am not baking the factor into the LoadCase object itself, because a load case can appear in different scenarios with different factors!


Solution

  • This configuration might look less horrifying with better names for the association and its relationships, but it seems to work. (I edited the answer to improve the names; calling the middle class "links" might be better still.) Even at the SQL level, this kind of many-to-many relationship will not perform very well if you need to traverse relationships of variable depth.

    If you don't always need the factor, you can create additional viewonly relationships (viewonly=True) that bypass LoadCaseAssociation.
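To make the remark about variable-depth traversal concrete, here is a minimal sketch of what such a traversal looks like at the SQL level, written as a recursive CTE and run through Python's built-in sqlite3 module. The table and column names match the models in this answer. The helper names (build_demo_db, descendants), the extra 'top' load case, and the assumption that the graph contains no cycles are mine and only serve the illustration.

```python
import sqlite3


def build_demo_db():
    """Create an in-memory database with a small two-level load case graph."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE load_cases (id INTEGER PRIMARY KEY, title TEXT);
        CREATE TABLE load_case_associations (
            parent_load_case_id INTEGER REFERENCES load_cases(id),
            child_load_case_id INTEGER REFERENCES load_cases(id),
            factor REAL,
            PRIMARY KEY (parent_load_case_id, child_load_case_id)
        );
        INSERT INTO load_cases (id, title) VALUES
            (1, 'ulc_t'), (2, 'ulc_p'), (3, 'ulc_m'), (4, 'clc'), (5, 'top');
        INSERT INTO load_case_associations VALUES
            (4, 1, 1.5), (4, 2, 1.0), (4, 3, -2.7), (5, 4, 2.0);
        """
    )
    return conn


def descendants(conn, root_id):
    """Return (title, depth, factor) for every load case below root_id.

    `factor` is the factor stored on the link that reaches the row.
    Assumes the graph is acyclic; a cycle would make the recursion run forever.
    """
    query = """
        WITH RECURSIVE tree(id, depth, factor) AS (
            -- Direct children of the root.
            SELECT child_load_case_id, 1, factor
            FROM load_case_associations
            WHERE parent_load_case_id = ?
            UNION ALL
            -- Children of rows that were found in the previous step.
            SELECT a.child_load_case_id, tree.depth + 1, a.factor
            FROM load_case_associations AS a
            JOIN tree ON a.parent_load_case_id = tree.id
        )
        SELECT lc.title, tree.depth, tree.factor
        FROM tree
        JOIN load_cases AS lc ON lc.id = tree.id
        ORDER BY tree.depth, lc.title
    """
    return conn.execute(query, (root_id,)).fetchall()
```

Calling descendants(build_demo_db(), 5) lists clc at depth 1 and the three ulc_* cases at depth 2. Every extra level of nesting adds another join pass inside the database, which is the cost the paragraph above hints at.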

    
    from sqlalchemy import Column
    from sqlalchemy import create_engine
    from sqlalchemy import ForeignKey
    from sqlalchemy import String
    from sqlalchemy import Float, Integer
    from sqlalchemy import select
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.orm import relationship
    from sqlalchemy.orm import Session, joinedload
    
    
    SQLALCHEMY_DB_URL: str = "sqlite:///./loadcases.sqlite"
    
    
    Base = declarative_base()
    
    
    # Tell SQLAlchemy which foreign key goes to which relationship on BOTH sides
    # of the relationship.
    child_relationship_kwargs = dict(
        foreign_keys="LoadCaseAssociation.child_load_case_id")
    parent_relationship_kwargs = dict(
        foreign_keys="LoadCaseAssociation.parent_load_case_id")
    
    
    class LoadCaseAssociation(Base):
        __tablename__ = "load_case_associations"
        # The composite primary key allows each (parent, child) pair only once.
        # If the same pair can repeat, use a separate id column instead.
        parent_load_case_id = Column(Integer, ForeignKey("load_cases.id"), primary_key=True)
        child_load_case_id = Column(Integer, ForeignKey("load_cases.id"), primary_key=True)
    
        child_load_case = relationship('LoadCase', back_populates='associations_as_child', **child_relationship_kwargs)
    
        parent_load_case = relationship('LoadCase', back_populates='associations_as_parent', **parent_relationship_kwargs)
    
        factor = Column(Float, nullable=True)
    
    
    class LoadCase(Base):
        __tablename__ = "load_cases"
    
        id = Column(Integer, primary_key=True, index=True)
        title = Column(String, index=True)
    
        associations_as_child = relationship('LoadCaseAssociation', back_populates='child_load_case', **child_relationship_kwargs)
    
        associations_as_parent = relationship('LoadCaseAssociation', back_populates='parent_load_case', **parent_relationship_kwargs)
    
    if __name__ == '__main__':
        engine = create_engine(
            SQLALCHEMY_DB_URL,
            echo=True,
            connect_args={"check_same_thread": False}  # needed for sqlite, remove when using other dbs
        )
        Base.metadata.create_all(bind=engine)
    
        with Session(engine) as session:
            ulc_t = LoadCase(title='ulc_t')
            ulc_p = LoadCase(title='ulc_p')
            ulc_m = LoadCase(title='ulc_m')
            session.add_all([ulc_t, ulc_p, ulc_m])
            clc = LoadCase(title='clc')
            for index, child in enumerate([ulc_t, ulc_p, ulc_m]):
                assoc = LoadCaseAssociation(child_load_case=child, factor=index)
                session.add(assoc)
                clc.associations_as_parent.append(assoc)
            session.add(clc)
            session.commit()
    
        # Check that ulc_t is actually a child of clc.
        with Session(engine) as session:
            clc = session.scalar(select(LoadCase).where(LoadCase.title == 'clc'))
            child_load_case_titles = sorted(assoc.child_load_case.title for assoc in clc.associations_as_parent)
            assert len(child_load_case_titles) == 3
            assert 'ulc_t' in child_load_case_titles
    
        # Check that clc is actually a parent of ulc_t and check that the correct
        # factor is set.
        with Session(engine) as session:
            ulc_t = session.scalar(select(LoadCase).where(LoadCase.title == 'ulc_t'))
            assoc = list(ulc_t.associations_as_child)[0]
            assert len(list(ulc_t.associations_as_child)) == 1
            assert assoc.factor == float(0)
            assert assoc.parent_load_case.title == 'clc'
    
        # Create more associations, this time by setting both ends on the association
        # directly instead of appending to the parent's relationship.
        with Session(engine) as session:
            a_a = LoadCase(title='a_a')
            a_b = LoadCase(title='a_b')
            a_c = LoadCase(title='a_c')
            session.add_all([a_a, a_b, a_c])
            a = LoadCase(title='a')
            session.add(a)
            for index, child in enumerate([a_a, a_b, a_c]):
                session.add(LoadCaseAssociation(parent_load_case=a, child_load_case=child, factor=index))
            session.commit()
    
        # Load the parent, its associations, and the associations' children with a single query.
        with Session(engine) as session:
            stmt = (
                select(LoadCase)
                .where(LoadCase.title == 'a')
                .options(
                    joinedload(LoadCase.associations_as_parent)
                    .joinedload(LoadCaseAssociation.child_load_case)
                )
            )
            a = session.scalar(stmt)
            # Check the log (echo=True) to confirm that only a single query is emitted.
            assert 'a_a' in [association.child_load_case.title for association in a.associations_as_parent]
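
The viewonly shortcut mentioned at the top of this answer could look roughly like the sketch below. It repeats the two models so that the block runs on its own. The relationship name children, the helper child_titles, and the in-memory SQLite URL are my own choices for illustration, not part of the original setup.

```python
from sqlalchemy import Column, Float, ForeignKey, Integer, String
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class LoadCaseAssociation(Base):
    __tablename__ = "load_case_associations"

    parent_load_case_id = Column(Integer, ForeignKey("load_cases.id"), primary_key=True)
    child_load_case_id = Column(Integer, ForeignKey("load_cases.id"), primary_key=True)
    factor = Column(Float, nullable=True)

    parent_load_case = relationship(
        "LoadCase", back_populates="associations_as_parent",
        foreign_keys="LoadCaseAssociation.parent_load_case_id")
    child_load_case = relationship(
        "LoadCase", back_populates="associations_as_child",
        foreign_keys="LoadCaseAssociation.child_load_case_id")


class LoadCase(Base):
    __tablename__ = "load_cases"

    id = Column(Integer, primary_key=True)
    title = Column(String, index=True)

    associations_as_parent = relationship(
        "LoadCaseAssociation", back_populates="parent_load_case",
        foreign_keys="LoadCaseAssociation.parent_load_case_id")
    associations_as_child = relationship(
        "LoadCaseAssociation", back_populates="child_load_case",
        foreign_keys="LoadCaseAssociation.child_load_case_id")

    # Hypothetical read-only shortcut: goes straight to the child LoadCase
    # objects through the association table, skipping the factor.
    children = relationship(
        "LoadCase",
        secondary="load_case_associations",
        primaryjoin="LoadCase.id == LoadCaseAssociation.parent_load_case_id",
        secondaryjoin="LoadCase.id == LoadCaseAssociation.child_load_case_id",
        viewonly=True,
    )


def child_titles(parent_title="clc"):
    """Store one parent with three children, then read the children back via the shortcut."""
    engine = create_engine("sqlite://")  # in-memory database, for the demo only
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        parent = LoadCase(title=parent_title)
        for factor, title in [(1.5, "ulc_t"), (1.0, "ulc_p"), (-2.7, "ulc_m")]:
            # Writes still go through the association object so the factor is stored.
            parent.associations_as_parent.append(
                LoadCaseAssociation(child_load_case=LoadCase(title=title), factor=factor))
        session.add(parent)
        session.commit()
        loaded = session.scalar(select(LoadCase).where(LoadCase.title == parent_title))
        return sorted(child.title for child in loaded.children)
```

Because the relationship is declared with viewonly=True, it is for reading only: appending to children does not create association rows, so factors still have to be set through LoadCaseAssociation as in the main example. A matching parents shortcut would swap the two join conditions.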