Search code examples
sqlalchemy

SqlAlchemy 1.4 ORM - Correct Way to Select filtered by attribute .in_ 2nd selection


We recently updated SQLAlchemy to 1.4, where we use the ORM API.

def get_data(session, target_batch):
    """Load the (FOO_ID, BAR_ID) pairs for one batch into a DataFrame.

    Selects FooByBar.foo_id / FooByBar.bar_id for the rows whose bar_id is
    among the ids of the Bar rows with batch_number == target_batch, then
    passes the resulting statement to pd.read_sql_query together with the
    session's bind.
    """
    # Ids of the bars in the requested batch, wrapped with .subquery().
    in_scope_bar_ids = session.query(Bar.id).filter(Bar.batch_number == target_batch).subquery()

    query = (
        session.query(
            FooByBar.foo_id.label("FOO_ID"),
            FooByBar.bar_id.label("BAR_ID"),
        )
        # Handing the Subquery object to in_() is the call that SQLAlchemy 1.4
        # reports with "Coercing Subquery object into a select()".
        .filter(FooByBar.bar_id.in_(in_scope_bar_ids))
    )

    df = pd.read_sql_query(query.selectable, session.get_bind())
    return df

This now triggers a warning like: # SAWarning: Coercing Subquery object into a select() for use in IN(); please pass a select() construct explicitly .filter(FooByBar.bar_id.in_(in_scope_bar_ids))

Dropping the .subquery() construct seems to fix it, but I can't find anything in the documentation explaining what changed. What's the proper, readable way to get a group of ids from table Bar and then filter FooByBar by those ids? Thank you!


Solution

  • From the 1.4 docs for subquery():

    A subquery is from a SQL perspective a parenthesized, named construct that can be placed in the FROM clause of another SELECT statement.

    So I think the issue is that a named, FROM-clause subquery is not what you intended here. You just want an unaliased, anonymous subquery, which is simply a plain SELECT. It seems you can either drop the .subquery() call or start using select(). I think another issue is that SQLAlchemy has to assume the subquery will return scalars and not something different. The warning is probably just SQLAlchemy walking back some old implicit conversions from years ago that are now blocking progress toward 2.0+.

    There seems to be a little bit of background in this issue 3d99ee28ed368c3bdbeaf872ef65b0c9a7c0da33

    Examples

    import sys
    from sqlalchemy import (
        create_engine,
        Integer,
    )
    from sqlalchemy.schema import (
        Column,
        ForeignKey,
    )
    from sqlalchemy.sql import select
    from sqlalchemy.orm import declarative_base, Session
    
    
    # Declarative base class that the mapped models in this script inherit from.
    Base = declarative_base()
    
    
    # Connection settings come from the first three command-line arguments;
    # the unpacking raises ValueError when fewer than three are supplied.
    username, password, db = sys.argv[1:4]
    
    
    # The URL names no host. echo=True makes the engine log the SQL it emits,
    # which is how the variants below can be compared.
    # NOTE(review): the credentials are interpolated as-is; a password with
    # URL-special characters (e.g. "@" or "/") presumably needs escaping first.
    engine = create_engine(f"postgresql+psycopg2://{username}:{password}@/{db}", echo=True)
    
    
    class Bar(Base):
        """Mapped class for the ``bars`` table."""
        __tablename__ = "bars"
        id = Column(Integer, primary_key=True)
        # Required batch number; the get_data* helpers filter bars on it.
        batch_number = Column(Integer, nullable=False)
    
    class Foo(Base):
        """Mapped class for the ``foos`` table; its only column is the primary key."""
        __tablename__ = "foos"
        id = Column(Integer, primary_key=True)
    
    
    class FooByBar(Base):
        """Mapped class for the ``foos_bars`` table, pairing a foo id with a bar id."""
        __tablename__ = 'foos_bars'
        id = Column(Integer, primary_key=True)
        # Both references are required (nullable=False) foreign keys.
        foo_id = Column(Integer, ForeignKey('foos.id'), nullable=False)
        bar_id = Column(Integer, ForeignKey('bars.id'), nullable=False)
    
    
    # Issue CREATE TABLE for the models registered on Base, using the engine above.
    Base.metadata.create_all(engine)
    
    
    def get_data(session, target_batch):
        """Fetch FooByBar pairs for a batch, filtering with a ``Subquery`` object.

        This is the variant that reproduces the SQLAlchemy 1.4 warning
        "Coercing Subquery object into a select() for use in IN()": the inner
        query is converted with ``.subquery()`` before being given to ``in_()``.

        :param session: ORM session used to build and run the queries.
        :param target_batch: value compared against ``Bar.batch_number``.
        :returns: ``None``; the rows are fetched and then discarded.
        """
        # Ids of the bars in the requested batch, wrapped as a subquery.
        bar_ids_in_batch = (
            session.query(Bar.id)
            .filter(Bar.batch_number == target_batch)
            .subquery()
        )

        foo_id_col = FooByBar.foo_id.label("FOO_ID")
        bar_id_col = FooByBar.bar_id.label("BAR_ID")
        labelled_pairs = session.query(foo_id_col, bar_id_col)

        # Execute the statement; only the emitted SQL / warning is of interest.
        labelled_pairs.filter(FooByBar.bar_id.in_(bar_ids_in_batch)).all()
    
    def get_data_no_subquery(session, target_batch):
        """Fetch FooByBar pairs for a batch, filtering with the ``Query`` itself.

        Same statement shape as the ``.subquery()`` variant, except that the
        inner ``session.query(...)`` is handed to ``in_()`` without first being
        converted with ``.subquery()``.

        :param session: ORM session used to build and run the queries.
        :param target_batch: value compared against ``Bar.batch_number``.
        :returns: ``None``; the rows are fetched and then discarded.
        """
        # Ids of the bars in the requested batch, left as a plain Query.
        bar_ids_in_batch = (
            session.query(Bar.id)
            .filter(Bar.batch_number == target_batch)
        )

        foo_id_col = FooByBar.foo_id.label("FOO_ID")
        bar_id_col = FooByBar.bar_id.label("BAR_ID")
        labelled_pairs = session.query(foo_id_col, bar_id_col)

        # Execute the statement; only the emitted SQL is of interest.
        labelled_pairs.filter(FooByBar.bar_id.in_(bar_ids_in_batch)).all()
    
    def get_data_select(session, target_batch):
        """Fetch FooByBar pairs for a batch, filtering with a ``select()`` construct.

        The inner statement is built with ``select(Bar.id).where(...)`` and is
        passed to ``in_()`` as-is, which is the form the 1.4 warning asks for
        ("please pass a select() construct explicitly").

        :param session: ORM session used to build and run the outer query.
        :param target_batch: value compared against ``Bar.batch_number``.
        :returns: ``None``; the rows are fetched and then discarded.
        """
        # Ids of the bars in the requested batch, as an explicit SELECT.
        bar_ids_in_batch = select(Bar.id).where(
            Bar.batch_number == target_batch
        )

        foo_id_col = FooByBar.foo_id.label("FOO_ID")
        bar_id_col = FooByBar.bar_id.label("BAR_ID")
        labelled_pairs = session.query(foo_id_col, bar_id_col)

        # Execute the statement; only the emitted SQL is of interest.
        labelled_pairs.filter(FooByBar.bar_id.in_(bar_ids_in_batch)).all()
    
    
    # Seed the database: two bars (batch numbers 1 and 2), two foos, and one
    # FooByBar row linking f1 to b1 and another linking f2 to b2.
    with Session(engine) as session:
        b1 = Bar(batch_number=1)
        b2 = Bar(batch_number=2)
        session.add_all([b1, b2])
        f1 = Foo()
        f2 = Foo()
        session.add_all([f1, f2])
        # Flush before building the link rows, which read f1.id / b1.id etc.
        session.flush()
        f1b1 = FooByBar(foo_id=f1.id, bar_id=b1.id)
        f2b2 = FooByBar(foo_id=f2.id, bar_id=b2.id)
        session.add_all([f1b1, f2b2])
        # NOTE(review): commit() presumably flushes pending rows itself, so
        # this explicit flush looks redundant -- confirm before removing.
        session.flush()
    
        session.commit()
    
    # Run each variant once for batch number 1 in a fresh session.
    with Session(engine) as session:
        # WARNING: this variant passes a Subquery object to in_().
        get_data(session, 1)
    
        # Inner Query passed to in_() without .subquery().
        get_data_no_subquery(session, 1)
    
        # Inner statement built with select() and passed to in_().
        get_data_select(session, 1)