SQLAlchemy column_property to check whether a value is in a set


I have a SQLAlchemy model and a set:

NON_RE_RUNNABLE_PIPELINES = {'pipe2', 'pipe1'}
class PipelineStatus(Base):
    __tablename__ = "pipeline_status"

    id = Column(UUID(as_uuid=True), primary_key=True, unique=True, nullable=False)
    pipeline = Column(String)
    re_runnable = column_property(pipeline not in NON_RE_RUNNABLE_PIPELINES)

My intention was that the re_runnable field would be True if the pipeline field is contained in the set, and False otherwise. But this is proving incredibly difficult to do.

When I query the code above, I get AttributeError: 'bool' object has no attribute '_deannotate'

If I try to use re_runnable = column_property(literal(pipeline not in NON_RE_RUNNABLE_PIPELINES)) instead, it also doesn't work, as it will check whether the NAME of the field is in the set and this is not what I want.

It seems the python expression I put inside column_property is not really executed on raw values, but instead it's converted into an expression with existing columns, which makes it hard for me to use it with a native set.

How can I implement such logic in my model?


Solution

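  • Use re_runnable = column_property(pipeline.in_(NON_RE_RUNNABLE_PIPELINES)). Python's in and not in operators are evaluated immediately, when the class body runs, and return a plain bool. in_() instead builds a SQL IN expression that the database evaluates for each row. The output below shows that the filter works.

    For the inverse check, use re_runnable = column_property(pipeline.not_in(NON_RE_RUNNABLE_PIPELINES))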
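To make the difference concrete, the sketch below compares Python's not in with in_() and not_in() on a standalone Column. The standalone column and the names python_result, in_expr and not_in_expr are illustrative assumptions, not part of the model. not_in assumes SQLAlchemy 1.4 or later (earlier releases spell it notin_).

```python
from sqlalchemy import Column, String

NON_RE_RUNNABLE_PIPELINES = {"pipe2", "pipe1"}

# Standalone column used only for illustration (not attached to any table).
pipeline = Column("pipeline", String)

# Python's membership test runs immediately and yields a plain bool,
# so nothing is left for SQLAlchemy to translate into SQL.
python_result = pipeline not in NON_RE_RUNNABLE_PIPELINES
print(type(python_result).__name__)

# in_() / not_in() build SQL expression objects instead of evaluating anything.
# A sorted list is passed only to keep the parameter order deterministic.
# not_in() needs SQLAlchemy 1.4+; older releases call it notin_().
in_expr = pipeline.in_(sorted(NON_RE_RUNNABLE_PIPELINES))
not_in_expr = pipeline.not_in(sorted(NON_RE_RUNNABLE_PIPELINES))

print(type(in_expr).__name__)
print(in_expr)      # renders an IN clause with a placeholder for the values
print(not_in_expr)  # renders a NOT IN clause
```

Passing the bool to column_property is what triggers the '_deannotate' AttributeError from the question, since column_property expects a SQL expression object.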

    A complete example follows:

    from uuid import uuid4
    from sqlalchemy import create_engine, select, Column, String
    from sqlalchemy.orm import column_property, declarative_base, Session
    from sqlalchemy.dialects.postgresql import UUID
    
    Base = declarative_base()
    
    NON_RE_RUNNABLE_PIPELINES = {"pipe2", "pipe1"}
    
    class PipelineStatus(Base):
        __tablename__ = "pipeline_status"
        id = Column(UUID(as_uuid=True), primary_key=True, unique=True, nullable=False)
        pipeline = Column(String)
        re_runnable = column_property(pipeline.in_(NON_RE_RUNNABLE_PIPELINES))
    
    # Placeholder: substitute a real PostgreSQL connection URL here.
    engine = create_engine('connection_string')
    
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)
    
    with Session(engine) as session:
        session.add(PipelineStatus(id=uuid4(), pipeline="pipe1"))
        session.add(PipelineStatus(id=uuid4(), pipeline="pipe2"))
        session.add(PipelineStatus(id=uuid4(), pipeline="pipe3"))
        session.add(PipelineStatus(id=uuid4(), pipeline="pipe4"))
        session.commit()
    
    with Session(engine) as session:
        for i in session.scalars(select(PipelineStatus).where(PipelineStatus.re_runnable.is_(True))):
            print(i.pipeline, i.re_runnable)
    

    Output

    pipe1 True
    pipe2 True