python sqlalchemy alembic

How to interrogate/alter a database sequence in python?


I'm using alembic and SQLAlchemy to work with different database types. I use my own create_sequence method and the drop_sequence method of alembic's Operations. Now I'm writing unit tests for this functionality, and I want to alter or interrogate the sequence I created before. But how?

    from sqlalchemy.schema import CreateSequence

    def create_sequence(self, sequence):
        # sequence is a sqlalchemy Sequence object; this emits CREATE SEQUENCE
        self.oprt.execute(CreateSequence(sequence))

    def drop_sequence(self, sequence_name):
        self.oprt.drop_sequence(sequence_name)

self.oprt is initialized like this:

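    from alembic.migration import MigrationContext
    from alembic.operations import Operations
    from sqlalchemy import create_engine

    engine = create_engine(connection_string, echo=True)  # echo=True logs the executed SQL
    conn = engine.connect()
    ctx = MigrationContext.configure(connection=conn)
    self.oprt = Operations(ctx)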

I have already tried to get a Sequence object via the engine object or a MetaData object, but I haven't got it to work yet.
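To see exactly which statements sit behind the two methods above, SQLAlchemy's DDL constructs can be compiled against the PostgreSQL dialect without any database connection. This is a minimal sketch: the sequence name and `start=10` are illustrative assumptions, and `DropSequence` is SQLAlchemy's counterpart to `CreateSequence`; whether the question's `drop_sequence` operation emits exactly this statement depends on how that operation is implemented.

```python
from sqlalchemy import Sequence
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateSequence, DropSequence

# Hypothetical sequence; start=10 only shows that options end up in the DDL.
seq = Sequence("counter", start=10)
pg = postgresql.dialect()

# Compile the DDL constructs to SQL strings; no database connection is needed.
create_sql = str(CreateSequence(seq).compile(dialect=pg))
drop_sql = str(DropSequence(seq).compile(dialect=pg))
print(create_sql)
print(drop_sql)
```

With `echo=True` on the engine, these are the kinds of statements that should show up in the log when `self.oprt.execute(CreateSequence(...))` runs.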


Solution

  • Here are some ideas. I only tested them with PostgreSQL, so I'm not sure how many other databases support them.

    Get the next_value() of the sequence.

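    from sqlalchemy import Sequence, create_engine, select
    from sqlalchemy.orm import Session, declarative_base
    
    engine = create_engine(f"postgresql+psycopg2://{username}:{password}@/{db}", echo=True)
    
    Base = declarative_base()
    
    metadata = Base.metadata
    
    seq_name = 'counter'
    
    # Associating the Sequence with the MetaData makes create_all() emit CREATE SEQUENCE
    counter_seq = Sequence(seq_name, metadata=metadata)
    
    metadata.create_all(engine)
    
    with Session(engine) as session, session.begin():
        # Runs SELECT nextval('counter'), which advances the sequence
        res = session.execute(select(counter_seq.next_value())).scalar()
        assert res > 0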
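The same `next_value()` call also works without the ORM Session, on a plain SQLAlchemy Connection such as the one behind the question's alembic `Operations` object (alembic exposes it via `Operations.get_bind()`). A minimal sketch; the helper names `next_value_stmt` and `fetch_next_value` are hypothetical, and the last lines only compile the statement for PostgreSQL to show the SQL it sends.

```python
from sqlalchemy import Sequence, select
from sqlalchemy.dialects import postgresql


def next_value_stmt(seq_name):
    # SELECT wrapping Sequence.next_value(), as in the example above
    return select(Sequence(seq_name).next_value())


def fetch_next_value(connection, seq_name):
    # Works on any SQLAlchemy Connection, e.g. the one returned by
    # Operations.get_bind() in alembic; note that it advances the sequence.
    return connection.execute(next_value_stmt(seq_name)).scalar()


# Inspect the SQL sent to PostgreSQL without touching a database.
sql = str(next_value_stmt("counter").compile(dialect=postgresql.dialect()))
print(sql)
```

Inside the question's class, a test could then call something like `fetch_next_value(self.oprt.get_bind(), "counter")`.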
    

    Use inspect() and check whether the sequence name is listed

    from sqlalchemy import inspect
    ins = inspect(engine)
    assert seq_name in ins.get_sequence_names()
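Putting the inspector check and `next_value()` together gives a create, interrogate, drop round trip of the kind the question's unit tests need. A minimal sketch under stated assumptions: it uses a plain SQLAlchemy connection instead of the question's `Operations` wrapper, the environment variable `TEST_DATABASE_URL` and the sequence name `counter_test` are hypothetical, and the test skips itself when no database URL is configured.

```python
import os
import unittest

from sqlalchemy import Sequence, create_engine, inspect, select
from sqlalchemy.schema import CreateSequence, DropSequence

# Hypothetical environment variable pointing at a throwaway PostgreSQL database,
# e.g. "postgresql+psycopg2://user:password@/testdb".
DB_URL = os.environ.get("TEST_DATABASE_URL")


@unittest.skipUnless(DB_URL, "set TEST_DATABASE_URL to run against PostgreSQL")
class SequenceRoundTripTest(unittest.TestCase):
    def test_create_interrogate_drop(self):
        seq = Sequence("counter_test")  # hypothetical sequence name
        engine = create_engine(DB_URL)
        try:
            # One transaction: PostgreSQL DDL is transactional, so a failing
            # assertion also rolls back the CREATE SEQUENCE.
            with engine.begin() as conn:
                conn.execute(CreateSequence(seq))
                # Interrogate: the inspector lists the new sequence...
                self.assertIn("counter_test", inspect(conn).get_sequence_names())
                # ...and it hands out increasing values.
                first = conn.execute(select(seq.next_value())).scalar()
                second = conn.execute(select(seq.next_value())).scalar()
                self.assertGreater(second, first)
                conn.execute(DropSequence(seq))
                self.assertNotIn("counter_test", inspect(conn).get_sequence_names())
        finally:
            engine.dispose()
```

Run it with `python -m unittest` after setting the variable. With the question's setup, the CREATE/DROP calls would go through `self.oprt`, and the reads through `self.oprt.get_bind()`.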
    
    

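    PostgreSQL only: check currval() manually

    PostgreSQL's currval() returns the value most recently obtained by nextval() for a sequence in the current session, but SQLAlchemy doesn't seem to support it directly. You can call it manually through func. Note that currval() only works after nextval() has been called for that sequence on the same database connection; otherwise PostgreSQL raises an error: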

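    from sqlalchemy.orm import Session
    from sqlalchemy.sql import func, select
    
    with Session(engine) as session:
        # currval() needs a prior nextval() on this same connection
        session.execute(select(counter_seq.next_value()))
        res = session.execute(select(func.currval(seq_name))).scalar()
        assert res > 0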
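For the "alter" half of the question, and for reading a sequence without the currval() restriction, PostgreSQL offers a few more options: a sequence can be selected from like a one-row table (`last_value`, `is_called`), `setval()` sets its value, and `ALTER SEQUENCE ... RESTART WITH` restarts it. A minimal, PostgreSQL-specific sketch; the helper names are hypothetical, and because the sequence name is interpolated into SQL text it must be a trusted identifier, not user input. The last lines only render the statements, no database is contacted.

```python
from sqlalchemy import func, select, text
from sqlalchemy.dialects import postgresql


def read_state_stmt(seq_name):
    # Selecting from the sequence works in any session, unlike currval().
    return text(f"SELECT last_value, is_called FROM {seq_name}")


def setval_stmt(seq_name, value):
    # setval(): the next nextval() returns value + increment.
    return select(func.setval(seq_name, value))


def restart_stmt(seq_name, value):
    # ALTER SEQUENCE ... RESTART WITH: the next nextval() returns exactly value.
    return text(f"ALTER SEQUENCE {seq_name} RESTART WITH {int(value)}")


pg = postgresql.dialect()
setval_sql = str(
    setval_stmt("counter", 100).compile(dialect=pg, compile_kwargs={"literal_binds": True})
)
restart_sql = str(restart_stmt("counter", 100))
print(setval_sql)
print(restart_sql)
```

In the question's setup, the ALTER statement could be passed to `self.oprt.execute(...)`, while the statements that return rows would be run on `self.oprt.get_bind()` so their results can be read.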