SQLAlchemy 2.x: Eagerly load joined collection query


Context

With SQLAlchemy 2.x, how do I eagerly load a joined collection?

Let's say we have the following models Parent and Child:

class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))

    children: Mapped[List["Child"]] = relationship(
        back_populates="parent", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"Parent(id={self.id!r}, name={self.name!r})"

class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))

    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    parent: Mapped["Parent"] = relationship(back_populates="children")

    def __repr__(self) -> str:
        return f"Child(id={self.id!r}, name={self.name!r})"

I would like to get the Parent that has a Child with id equal to 1, and have the resulting parent.children populated with all of that parent's children.

In other words, if our parent and child tables are populated with:

parent
id  name
1   p1
2   p2
child
id  name  parent_id
1   c1    1
2   c2    1
3   c2    1
4   c3    1
5   c4    2

I would like to see the query result object:

result = <query parent whose children has id == 1>
print(result)
>>> Parent(id=1, name='p1')

print(result.children)
>>> [
        Child(id=1, name='c1'),
        Child(id=2, name='c2'),
        Child(id=3, name='c3'),
        Child(id=4, name='c4'),
    ]

Test case 1

stmt = select(Parent).join(Parent.children).where(Child.id == 1)

Generates the following SQL:

SELECT parent.id, parent.name 
FROM parent JOIN child ON parent.id = child.parent_id 
WHERE child.id = 1

This looks great, but since the statement doesn't tell SQLAlchemy to eagerly load children, accessing parent.children on the resulting scalar object triggers a lazy load, and I get the error:

sqlalchemy.exc.MissingGreenlet: greenlet_spawn has not been called; can't call await_only() here. Was IO attempted in an unexpected place? (Background on this error at: https://sqlalche.me/e/20/xd2s)

Test case 2

stmt = select(Parent).options(joinedload(Parent.children)).where(Child.id == 1)

Generates the following SQL:

SELECT parent.id, parent.name, child_1.id AS id_1, child_1.name AS name_1, child_1.parent_id 
FROM parent JOIN child AS child_1 ON parent.id = child_1.parent_id, child 
WHERE child.id = 1

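This is not what we're looking for: child now appears as a second, unjoined entry in the FROM clause, because joinedload() joins against an anonymous alias (child_1) that the WHERE clause does not refer to.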
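
To make the effect of that stray FROM entry concrete, here is a small sketch that runs the Test case 2 SQL as-is using only the standard-library sqlite3 module. The in-memory database and the sample rows (children named c1 to c5) are assumptions for illustration:

```python
import sqlite3

# Assumption: in-memory SQLite with illustrative sample rows.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE parent (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE child (
        id INTEGER PRIMARY KEY,
        name TEXT,
        parent_id INTEGER REFERENCES parent (id)
    );
    INSERT INTO parent VALUES (1, 'p1'), (2, 'p2');
    INSERT INTO child VALUES
        (1, 'c1', 1), (2, 'c2', 1), (3, 'c3', 1), (4, 'c4', 1), (5, 'c5', 2);
    """
)

# The SQL from Test case 2: "child" is cross-joined rather than correlated,
# so "child.id = 1" does not restrict which parents come back.
rows = conn.execute(
    """
    SELECT parent.id, parent.name, child_1.id, child_1.name, child_1.parent_id
    FROM parent JOIN child AS child_1 ON parent.id = child_1.parent_id, child
    WHERE child.id = 1
    """
).fetchall()

parent_ids = sorted({row[0] for row in rows})
print(len(rows), parent_ids)
```

Every parent/child pair is returned once, so both parents show up in the result instead of only parent 1.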


Solution

  • I was kind of surprised, but joining with an alias appears to work, although it seems like a footgun for later. I'm a big fan of subqueries; that approach might be more flexible later on and maybe even faster.

    I did not test this with async, but I think it should work.

    Log output (echo=True) for the two queries in the script below:

    2023-05-27 15:44:06,791 INFO sqlalchemy.engine.Engine BEGIN (implicit)
    2023-05-27 15:44:06,809 INFO sqlalchemy.engine.Engine SELECT parents.id, parents.name, childs_1.id AS id_1, childs_1.name AS name_1, childs_1.parent_id 
    FROM parents JOIN childs AS childs_2 ON parents.id = childs_2.parent_id LEFT OUTER JOIN childs AS childs_1 ON parents.id = childs_1.parent_id 
    WHERE childs_2.id = %(id_2)s
    2023-05-27 15:44:06,809 INFO sqlalchemy.engine.Engine [generated in 0.00066s] {'id_2': 1}
    2023-05-27 15:44:06,814 INFO sqlalchemy.engine.Engine ROLLBACK
    2023-05-27 15:44:06,815 INFO sqlalchemy.engine.Engine BEGIN (implicit)
    2023-05-27 15:44:06,817 INFO sqlalchemy.engine.Engine SELECT parents.id, parents.name, childs_1.id AS id_1, childs_1.name AS name_1, childs_1.parent_id 
    FROM parents LEFT OUTER JOIN childs AS childs_1 ON parents.id = childs_1.parent_id 
    WHERE parents.id IN (SELECT childs.parent_id 
    FROM childs 
    WHERE childs.id = %(id_2)s)
    2023-05-27 15:44:06,817 INFO sqlalchemy.engine.Engine [generated in 0.00018s] {'id_2': 1}
    2023-05-27 15:44:06,819 INFO sqlalchemy.engine.Engine ROLLBACK
    
    import sys
    from sqlalchemy import (
        create_engine,
        Integer,
        String,
    )
    from sqlalchemy.schema import (
        Column,
        ForeignKey,
    )
    from sqlalchemy.sql import select
    from sqlalchemy.orm import declarative_base, Session, aliased, relationship, joinedload
    
    
    Base = declarative_base()
    
    
    username, password, db = sys.argv[1:4]
    
    
    engine = create_engine(f"postgresql+psycopg2://{username}:{password}@/{db}", echo=True)
    
    
    class Parent(Base):
        __tablename__ = "parents"
        id = Column(Integer, primary_key=True)
        name = Column(String)
        childs = relationship("Child", back_populates="parent")
    
    
    class Child(Base):
        __tablename__ = 'childs'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        parent_id = Column(Integer, ForeignKey('parents.id'), nullable=False)
        parent = relationship("Parent", back_populates="childs")
    
    
    Base.metadata.create_all(engine)
    
    parents = [line.split()[:2] for line in """
    1   p1
    2   p2
    """.splitlines() if line.strip()]
    
    # Changed child names to be consistent.
    childs = [line.split()[:3] for line in """
    1   c1    1
    2   c2    1
    3   c3    1
    4   c4    1
    5   c5    2
    """.splitlines() if line.strip()]
    
    with Session(engine) as session:
        session.add_all([Parent(id=int(parent[0]), name=parent[1]) for parent in parents])
        session.add_all([Child(id=int(child[0]), name=child[1], parent_id=int(child[2])) for child in childs])
        session.commit()
    
    def test_results(parents):
        # Exactly one parent (id 1) should come back, with all four of its children loaded.
        assert len(parents) == 1
        assert len(parents[0].childs) == 4
        assert sorted(child.id for child in parents[0].childs) == [1, 2, 3, 4]
    
    # Approach 1: filter through an aliased Child so the explicit join
    # does not conflict with the join that joinedload() adds.
    with Session(engine) as session:
        c2 = aliased(Child)
        q = (
            select(Parent)
            .join(c2, Parent.id == c2.parent_id)
            .where(c2.id == 1)
            .options(joinedload(Parent.childs))
        )
        # unique() is required when joinedload() targets a collection.
        parents = list(session.scalars(q).unique())
        test_results(parents)
    
    # Approach 2: filter with an IN subquery, leaving the only join to joinedload().
    with Session(engine) as session:
        q = (
            select(Parent)
            .where(Parent.id.in_(select(Child.parent_id).where(Child.id == 1)))
            .options(joinedload(Parent.childs))
        )
        parents = list(session.scalars(q).unique())
        test_results(parents)