Tags: python, postgresql, sqlalchemy, fastapi

Comparing different methods of rolling back database changes in pytest tests for SQLAlchemy


I'm working on a project that uses FastAPI and SQLAlchemy asynchronously.
I've written pytest tests for this project and have successfully implemented database rollback after each test run.
I've found two different implementation methods, but I'm unsure about the differences between them. Are both methods correct, or is one potentially problematic?

conftest.py

# pyproject.toml
#
# pytest = "^8.3.2"
# pytest-asyncio = "==0.21.2"
# #pytest-dependency = "^0.6.0"
# pytest-order = "^1.2.1"
#
# [tool.pytest.ini_options]
# addopts = "-s"
# asyncio_mode = "auto"

import asyncio
from urllib.parse import urlparse

import pytest
from sqlalchemy import NullPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession

from config import settings
from depends.db import async_session as api_async_session
from main import app

url = urlparse(settings.db)._replace(scheme="postgresql+asyncpg").geturl()


@pytest.fixture(scope="session")
def event_loop(request):
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
async def async_session():
    async_db = create_async_engine(url, echo=False, poolclass=NullPool)
    async with async_db.connect() as connection:
        async with connection.begin() as transaction:
            async with AsyncSession(bind=connection) as s:
                app.dependency_overrides[api_async_session] = lambda: s
                yield s
            await transaction.rollback()

# Instead of using a connection pool, bind a specific connection to the event loop.
# If you don't set an event loop policy, or if you use pytest-asyncio 0.23,
# each test starts a new event loop, which causes asyncpg to raise exceptions.
@pytest.fixture
async def async_session2():
    async_db = create_async_engine(url, echo=False, poolclass=NullPool)
    async with async_db.connect() as connection:
        transaction = await connection.begin()
        async with AsyncSession(bind=connection, join_transaction_mode="create_savepoint") as s:
            app.dependency_overrides[api_async_session] = lambda: s
            yield s
        await transaction.rollback()

I've also read the official documentation for create_savepoint, but I find it too difficult to understand. Even after looking into the source code of AsyncTransaction's __aenter__ method, I'm still uncertain.


Solution

  • The advantage of correctly using savepoints in tests is that your application code can explicitly call commit/rollback during a test, and you can still roll back everything at the end of the test through the original outer transaction. With only a single transaction level, the application cannot explicitly call rollback or commit without affecting the test's transaction (in such applications you usually roll back or commit once, at the end of the request life cycle).

    There is a sync example of using savepoints here (and it looks like it matches your 2nd example):

    https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites

    I think the second example is better if you have PostgreSQL, although you shouldn't often need to commit explicitly several times within one request life cycle. Committing early can help if you need to release a lock, as in this fabricated example:

    • SELECT * FROM orders WHERE order_id = 15 FOR UPDATE
    • quickly do bookkeeping to prepare the order
    • order.state = "ready_to_ship"
    • session.commit() to free that order for other changes
    • do other related work but without the lock

    In this case the first example probably wouldn't work, because your explicit commit call would break your test's rollback, whereas in the second example the commit would only release a savepoint, so the true outer transaction could still be rolled back.
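
    To make the difference concrete at the SQL level, here is a minimal sketch using Python's built-in sqlite3 module rather than SQLAlchemy and PostgreSQL (an assumption made so the snippet runs anywhere). The table layout, the helper names make_db, app_code and state_after_test, and the savepoint name sp are all hypothetical. It models a commit that reaches the outer transaction versus a commit that only releases a savepoint; it does not reproduce SQLAlchemy's own session logic.

```python
import sqlite3


def make_db():
    # isolation_level=None puts the sqlite3 module in autocommit mode,
    # so every BEGIN / SAVEPOINT / ROLLBACK below is issued explicitly.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, state TEXT NOT NULL)")
    conn.execute("INSERT INTO orders (id, state) VALUES (15, 'start')")
    return conn


def app_code(conn, commit_sql):
    # Hypothetical application code that "commits" part-way through its work.
    conn.execute("UPDATE orders SET state = 'ready_to_ship' WHERE id = 15")
    conn.execute(commit_sql)


def state_after_test(use_savepoint):
    conn = make_db()
    conn.execute("BEGIN")  # outer transaction owned by the test fixture
    if use_savepoint:
        conn.execute("SAVEPOINT sp")
        # The application's commit only releases the savepoint.
        app_code(conn, "RELEASE SAVEPOINT sp")
    else:
        # The application's commit ends the outer transaction itself.
        app_code(conn, "COMMIT")
    if conn.in_transaction:
        conn.execute("ROLLBACK")  # fixture teardown
    state = conn.execute("SELECT state FROM orders WHERE id = 15").fetchone()[0]
    conn.close()
    return state
```

    Under these assumptions, state_after_test(True) leaves the row at 'start', while state_after_test(False) leaves it at 'ready_to_ship', because the COMMIT made the change permanent before the teardown could roll anything back.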

    Update

    There are actually even more subtle differences between "conditional_savepoint" (the default), "rollback_only", "create_savepoint" and "control_fully".

    In the default case while using PostgreSQL, if the connection is already in a transaction but no savepoint has been created, "conditional_savepoint" behaves like "rollback_only". In that mode, commits within the active transaction are ignored, but rollbacks are propagated up to the outer transaction. So with the first example above, once the application code calls rollback, subsequent commits or other interaction with the session will not work, whereas with "create_savepoint" the session continues to function.

    The example below demonstrates a case where the code will fail:

    import os
    import contextlib
    
    import pytest
    from sqlalchemy import (
        Column,
        String,
        BigInteger,
        create_engine,
    )
    from sqlalchemy.sql import (
        select,
        or_,
    )
    from sqlalchemy.orm import (
        declarative_base,
        Session,
    )
    
    
    def get_engine(env):
        return create_engine(f"postgresql+psycopg2://{env['DB_USER']}:{env['DB_PASSWORD']}@{env['DB_HOST']}:{env['DB_PORT']}/{env['DB_NAME']}", echo=True, echo_pool='debug')
    
    Base = declarative_base()
    
    
    class Order(Base):
        __tablename__ = 'orders'
        id = Column(BigInteger, primary_key=True)
        state = Column(String, nullable=False)
    
    
    def main():
        #
        # We run this script to setup the testing database.
        #
        engine = get_engine(os.environ)
    
        with engine.begin() as conn:
            Base.metadata.create_all(conn)
    
        with engine.connect() as conn:
            populate(conn)
    
        run_test(engine)
    
    
    def populate(conn):
        with Session(conn) as session:
            session.add_all([Order(id=1, state='start'), Order(id=2, state='processing'), Order(id=3, state='finished')])
            session.commit()
    
    
    def run_test(engine):
    
        # The session we get should already be bound to a connection
        # with an active transaction.
        with our_session_maker(engine) as session:
            order = session.execute(select(Order)).scalars().first()
            order.state = 'error'
            # Actually write changes to the db.
            session.flush()
            # Now try to roll them back.
            session.rollback()
            # Try to keep using the session.
            # Now try to make another change but commit this time.
            order.state = 'fixed'
            session.commit()
    
        # When this session is created it should be "clean",
        # i.e. not in a transaction.
        with Session(engine) as session:
            # The state change should never occur.
            assert not session.execute(select(Order).where(or_(Order.state == 'error', Order.state == 'fixed'))).scalars().all()
    
    
    @contextlib.contextmanager
    def our_session_maker(engine):
        with engine.connect() as connection:
            with connection.begin() as transaction:
                # one of conditional_savepoint, create_savepoint, control_fully, rollback_only.
                # The default is join_transaction_mode='conditional_savepoint'.
                with Session(bind=connection) as s:
                    yield s
                    transaction.rollback()
    
    
    if __name__ == '__main__':
        main()
    

    This produces a traceback like this (shortened). You can see the exception occurs on the second commit, i.e. the session.commit() call in run_test that follows the rollback.

    Traceback (most recent call last):
      File "/app/scripts/testing_sync_session_pytest.py", line 89, in <module>
        main()
      File "/app/scripts/testing_sync_session_pytest.py", line 45, in main
        run_test(engine)
      File "/app/scripts/testing_sync_session_pytest.py", line 68, in run_test
        session.commit()
    ...
    sqlalchemy.exc.InvalidRequestError: Can't operate on closed transaction inside context manager.  Please complete the context manager before emitting further commands.
    
    

    The logs in this case stop after the rollback but you can see that:

    • transaction is started
    • database is changed (via flush)
    • rollback occurs
    • immediate error when session tries to access the database again

    2024-08-04 22:51:49,974 INFO sqlalchemy.engine.Engine BEGIN (implicit)
    2024-08-04 22:51:49,975 INFO sqlalchemy.engine.Engine SELECT orders.id, orders.state 
    FROM orders
    2024-08-04 22:51:49,975 INFO sqlalchemy.engine.Engine [generated in 0.00015s] {}
    2024-08-04 22:51:49,976 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
    2024-08-04 22:51:49,976 INFO sqlalchemy.engine.Engine [generated in 0.00008s] {'state': 'error', 'orders_id': 1}
    2024-08-04 22:51:49,977 INFO sqlalchemy.engine.Engine ROLLBACK
    

    If join_transaction_mode='create_savepoint' is used, this example runs as expected, i.e.:

    
    @contextlib.contextmanager
    def our_session_maker(engine):
        with engine.connect() as connection:
            with connection.begin() as transaction:
                with Session(bind=connection, join_transaction_mode='create_savepoint') as s:
                    yield s
                    transaction.rollback()
    

    Looking at the logs you can see (while using 'create_savepoint'):

    • the transaction begins
    • a savepoint is started
    • the first change is made (via flush)
    • that change is rolled back to the savepoint
    • another savepoint is started
    • the second change is committed (the savepoint is released)
    • the entire transaction is rolled back (as we would want while using pytest)

    2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine BEGIN (implicit)
    2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine SAVEPOINT sa_savepoint_1
    2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine [no key 0.00006s] {}
    2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine SELECT orders.id, orders.state 
    FROM orders
    2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine [generated in 0.00007s] {}
    2024-08-04 22:48:12,021 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
    2024-08-04 22:48:12,021 INFO sqlalchemy.engine.Engine [generated in 0.00008s] {'state': 'error', 'orders_id': 1}
    2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine ROLLBACK TO SAVEPOINT sa_savepoint_1
    2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine [no key 0.00007s] {}
    2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine SAVEPOINT sa_savepoint_2
    2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine [no key 0.00006s] {}
    2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine SELECT orders.id AS orders_id 
    FROM orders 
    WHERE orders.id = %(pk_1)s
    2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine [generated in 0.00007s] {'pk_1': 1}
    2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
    2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine [cached since 0.002251s ago] {'state': 'fixed', 'orders_id': 1}
    2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine RELEASE SAVEPOINT sa_savepoint_2
    2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine [no key 0.00005s] {}
    2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine ROLLBACK
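
    The statement sequence in this log can be replayed by hand to check the outcome. The sketch below does so against an in-memory SQLite database through the standard sqlite3 module (an assumption; the log above comes from PostgreSQL), with a hypothetical current_state helper recording the row's state after each step. The SELECT statements that the ORM emits are left out.

```python
import sqlite3

# Autocommit mode: transaction control statements are sent exactly as written.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, state TEXT NOT NULL)")
conn.execute("INSERT INTO orders (id, state) VALUES (1, 'start')")


def current_state():
    # Hypothetical helper: read the state of order 1 as seen by this connection.
    return conn.execute("SELECT state FROM orders WHERE id = 1").fetchone()[0]


observed = []

conn.execute("BEGIN")                                   # outer test transaction
conn.execute("SAVEPOINT sa_savepoint_1")                # session joins via a savepoint
conn.execute("UPDATE orders SET state = 'error' WHERE id = 1")
observed.append(current_state())                        # change visible after flush

conn.execute("ROLLBACK TO SAVEPOINT sa_savepoint_1")    # session.rollback()
observed.append(current_state())                        # first change undone

conn.execute("SAVEPOINT sa_savepoint_2")                # session keeps working
conn.execute("UPDATE orders SET state = 'fixed' WHERE id = 1")
conn.execute("RELEASE SAVEPOINT sa_savepoint_2")        # session.commit()
observed.append(current_state())                        # visible inside the transaction

conn.execute("ROLLBACK")                                # outer rollback at teardown
observed.append(current_state())                        # nothing persisted
```

    Replayed this way, the row reads 'error', then 'start', then 'fixed', and finally 'start' again, which matches the assertion in run_test that neither state change survives.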
    

    Conclusion

    It would be pretty easy to use the default, "conditional_savepoint", and not notice a problem, because rollback is rarer than commit. My original fabricated example actually works fine with the default, but it will fail if you use "control_fully". For deterministic behavior during tests you probably want "create_savepoint" where it is available; otherwise code that works in production may not work correctly under test.

    More Info

    There is some context in the "What's New in SQLAlchemy 2.0" section of the documentation, under new-transaction-join-modes-for-session.

    Also, as mentioned above, there is an example in the documentation at joining-a-session-into-an-external-transaction-such-as-for-test-suites.