
Flask-SQLAlchemy Multiple filters on Parent and Child


I have the following parent class:

from typing import List

from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

# Base is the application's declarative base class (defined elsewhere)
class Pairing(Base):
    __tablename__ = 'pairing'
    id: Mapped[int] = mapped_column(Integer, unique=True, primary_key=True, autoincrement=True)
    pairing_no: Mapped[str] = mapped_column(String(10), unique=False, nullable=False)
    total_expenses: Mapped[str] = mapped_column(String(10), unique=False, nullable=True)
    pairing_tafb: Mapped[str] = mapped_column(String(10), unique=False, nullable=True)

    flights: Mapped[List["Flight"]] = relationship(
        back_populates="pairing", cascade='all, delete'
    )

And the following child class:

class Flight(Base):
    __tablename__ = 'flight'
    id: Mapped[int] = mapped_column(Integer, unique=True, primary_key=True, autoincrement=True)
    destination_station: Mapped[str] = mapped_column(String, unique=False)

    pairing_no: Mapped[str] = mapped_column(ForeignKey('pairing.pairing_no'), unique=False)
    pairing: Mapped["Pairing"] = relationship(back_populates='flights')

I can successfully apply multiple filters on the Pairing class using the following:

from flask import request
from sqlalchemy import select

# Collect the search fields posted by the form
pairing_filter = {
    'pairing_no': request.form.get('search_pairing'),
    'pairing_tafb': request.form.get('pairing_tafb'),
    'total_expenses': request.form.get('total_expenses'),
}
# Keep only the fields the user actually filled in
pairing_filter = {key: value for (key, value) in pairing_filter.items() if value}
pairings_to_display = db.session.scalars(
    select(Pairing).filter_by(**pairing_filter).order_by(Pairing.id)
).all()

But I also want to filter on the Flight class. For example, I want to narrow the pairings_to_display result down to pairings that contain at least one flight landing in LHR. Any combination of filters could apply, depending on user input.

How can I achieve this?


Solution

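  • To filter Pairing rows by a column of the related Flight model, join Flight first; SQLAlchemy derives the join condition from the ForeignKey. Because the join produces one row per matching flight, grouping by Pairing.id brings the result back to one row per pairing.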

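    # Inner join along the Flight.pairing_no -> Pairing.pairing_no ForeignKey;
    # GROUP BY Pairing.id collapses one-row-per-flight back to one row per pairing.
    stmt = (
        db.select(Pairing)
        .join(Flight)
        .group_by(Pairing.id)
        .order_by(Pairing.id)
    )
    # Map each form field name to the column it filters
    key_bindings = {
        'search_pairing': Pairing.pairing_no,
        'pairing_tafb': Pairing.pairing_tafb,
        'total_expenses': Pairing.total_expenses,
        'dest': Flight.destination_station,
    }
    # Add a WHERE condition only for the fields the user filled in
    for field, column in key_bindings.items():
        if value := request.form.get(field):
            stmt = stmt.where(column == value)
    results = db.session.scalars(stmt).all()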