I'm working on a project that uses FastAPI and SQLAlchemy asynchronously.
I've written pytest tests for this project and have successfully implemented database rollback after each test run.
I've found two different implementation methods, but I'm unsure about the differences between them. Are both methods correct, or is one potentially problematic?
# pyproject.toml
#
# pytest = "^8.3.2"
# pytest-asyncio = "==0.21.2"
# #pytest-dependency = "^0.6.0"
# pytest-order = "^1.2.1"
#
# [tool.pytest.ini_options]
# addopts = "-s"
# asyncio_mode = "auto"
import asyncio
from urllib.parse import urlparse
import pytest
from sqlalchemy import NullPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from config import settings
from depends.db import async_session as api_async_session
from main import app
url = urlparse(settings.db)._replace(scheme="postgresql+asyncpg").geturl()
@pytest.fixture(scope="session")
def event_loop(request):
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
async def async_session():
    async_db = create_async_engine(url, echo=False, poolclass=NullPool)
    async with async_db.connect() as connection:
        async with connection.begin() as transaction:
            async with AsyncSession(bind=connection) as s:
                app.dependency_overrides[api_async_session] = lambda: s
                yield s
            await transaction.rollback()
# Instead of using a connection pool, bind a specific connection to the event loop.
# If you don't set an event loop policy, or if you use pytest-asyncio 0.23,
# each test starts a new event loop, which causes asyncpg to raise exceptions.
@pytest.fixture
async def async_session2():
    async_db = create_async_engine(url, echo=False, poolclass=NullPool)
    async with async_db.connect() as connection:
        transaction = await connection.begin()
        async with AsyncSession(bind=connection, join_transaction_mode="create_savepoint") as s:
            app.dependency_overrides[api_async_session] = lambda: s
            yield s
        await transaction.rollback()
I've also checked the official documentation for create_savepoint, but it's too difficult to understand. Even after looking into the source code of AsyncTransaction's __aenter__ method, I'm still uncertain.
The advantage of correctly using savepoints in tests is that you can explicitly call commit/rollback inside your application code during a test and still roll back everything at the end of the test through the original transaction. With only a single transaction level, by contrast, you cannot explicitly call rollback and commit in your application (in that setup you usually roll back or commit only at the end of the request life cycle).
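Here is a minimal sketch of that idea at the SQL level. It uses Python's built-in sqlite3 module instead of PostgreSQL (an assumption, so it runs without a server), and the unit_of_work helper is hypothetical: it stands in for the application's commit/rollback, which inside the test's outer transaction only releases or rolls back to a savepoint.

```python
import itertools
import sqlite3
from contextlib import contextmanager

# sqlite3 stands in for PostgreSQL here (assumption). isolation_level=None
# stops the module from issuing its own BEGIN/COMMIT, so every
# transaction statement below is explicit.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, state TEXT NOT NULL)")
conn.execute("INSERT INTO orders VALUES (1, 'start')")

_savepoint_ids = itertools.count(1)


@contextmanager
def unit_of_work(conn):
    """Hypothetical stand-in for one application-level transaction.

    A normal exit plays the role of the app's commit (RELEASE SAVEPOINT);
    an exception plays the role of its rollback (ROLLBACK TO SAVEPOINT).
    The test's outer transaction is never committed or rolled back here.
    """
    name = f"sp_{next(_savepoint_ids)}"
    conn.execute(f"SAVEPOINT {name}")
    try:
        yield conn
    except Exception:
        conn.execute(f"ROLLBACK TO SAVEPOINT {name}")
        conn.execute(f"RELEASE SAVEPOINT {name}")
        raise
    conn.execute(f"RELEASE SAVEPOINT {name}")


def current_state():
    return conn.execute("SELECT state FROM orders WHERE id = 1").fetchone()[0]


conn.execute("BEGIN")  # the test's outer transaction

with unit_of_work(conn):  # the application "commits" this change
    conn.execute("UPDATE orders SET state = 'ready_to_ship' WHERE id = 1")

try:
    with unit_of_work(conn):  # the application "rolls back" this one
        conn.execute("UPDATE orders SET state = 'error' WHERE id = 1")
        raise RuntimeError("simulated application failure")
except RuntimeError:
    pass

state_during_test = current_state()
conn.execute("ROLLBACK")  # end of the test: undo everything
state_after_test = current_state()
print(state_during_test, state_after_test)  # ready_to_ship start
```

Inside the test the application sees its own committed change, but the outer ROLLBACK still returns the row to its starting state.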
There is a sync example of using savepoints in the SQLAlchemy documentation, in the section on joining a Session into an external transaction (and it looks like it matches your second example).
I think the second example is better if you are using PostgreSQL, although you shouldn't often need to call commit explicitly several times within one request life cycle. Committing early can help when you need to release a lock, as in this fabricated example:
SELECT orders.* FROM orders WHERE order_id = 15 FOR UPDATE
order.state = "ready_to_ship"
session.commit()
to free that order for other changes.
In this case the first example probably wouldn't work, because your explicit commit call would break your test's rollback, whereas the second example would commit against a savepoint, so the outer, real transaction would still be rolled back.
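The lock in the fabricated example can be expressed with SQLAlchemy's Select.with_for_update(). This small sketch compiles the statement against the PostgreSQL dialect, so no database is needed, and shows where FOR UPDATE lands in the generated SQL. The orders table and its columns are assumptions taken from the fabricated example, not a real schema.

```python
from sqlalchemy import column, select, table
from sqlalchemy.dialects import postgresql

# Lightweight table construct; the table and column names come from the
# fabricated example and are assumptions.
orders = table("orders", column("order_id"), column("state"))

# with_for_update() renders FOR UPDATE after the FROM and WHERE clauses.
stmt = select(orders).where(orders.c.order_id == 15).with_for_update()

# Compile against the PostgreSQL dialect; no connection is required.
sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)
```

In ORM code the statement would be passed to session.execute(); on PostgreSQL the row lock is held until the surrounding transaction ends.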
There are actually even more subtle differences between the join_transaction_mode values "conditional_savepoint" (the default), "rollback_only", "create_savepoint" and "control_fully".
In the default case (here with PostgreSQL), if you haven't already created a savepoint on the connection, "rollback_only" is used. That mode ignores commits within the active transaction, but propagates rollbacks up to the outer transaction. So after a rollback, the first example above will not let subsequent commits or other interaction work, whereas with "create_savepoint" the session continues to function.
The example below demonstrates a case where the code will fail:
import os
import contextlib

import pytest
from sqlalchemy import (
    Column,
    String,
    BigInteger,
    create_engine,
)
from sqlalchemy.sql import (
    select,
    or_,
)
from sqlalchemy.orm import (
    declarative_base,
    Session,
)


def get_engine(env):
    return create_engine(f"postgresql+psycopg2://{env['DB_USER']}:{env['DB_PASSWORD']}@{env['DB_HOST']}:{env['DB_PORT']}/{env['DB_NAME']}", echo=True, echo_pool='debug')


Base = declarative_base()


class Order(Base):
    __tablename__ = 'orders'
    id = Column(BigInteger, primary_key=True)
    state = Column(String, nullable=False)


def main():
    #
    # We run this script to set up the testing database.
    #
    engine = get_engine(os.environ)
    with engine.begin() as conn:
        Base.metadata.create_all(conn)
    with engine.connect() as conn:
        populate(conn)
    run_test(engine)


def populate(conn):
    with Session(conn) as session:
        session.add_all([Order(id=1, state='start'), Order(id=2, state='processing'), Order(id=3, state='finished')])
        session.commit()


def run_test(engine):
    # The session we get should already be bound to a connection
    # with an active transaction.
    with our_session_maker(engine) as session:
        order = session.execute(select(Order)).scalars().first()
        order.state = 'error'
        # Actually write changes to the db.
        session.flush()
        # Now try to roll them back.
        session.rollback()
        # Try to keep using the session.
        # Now try to make another change, but commit this time.
        order.state = 'fixed'
        session.commit()
    # When this session is created it should be "clean",
    # i.e. not in a transaction.
    with Session(engine) as session:
        # The state change should never occur.
        assert not session.execute(select(Order).where(or_(Order.state == 'error', Order.state == 'fixed'))).scalars().all()


@contextlib.contextmanager
def our_session_maker(engine):
    with engine.connect() as connection:
        with connection.begin() as transaction:
            # One of conditional_savepoint, create_savepoint, control_fully, rollback_only.
            # The default is join_transaction_mode='conditional_savepoint'.
            with Session(bind=connection) as s:
                yield s
            transaction.rollback()


if __name__ == '__main__':
    main()
This produces a traceback like the following (shortened). You can see that the exception occurs on the commit that follows the rollback.
Traceback (most recent call last):
File "/app/scripts/testing_sync_session_pytest.py", line 89, in <module>
main()
File "/app/scripts/testing_sync_session_pytest.py", line 45, in main
run_test(engine)
File "/app/scripts/testing_sync_session_pytest.py", line 68, in run_test
session.commit()
...
sqlalchemy.exc.InvalidRequestError: Can't operate on closed transaction inside context manager. Please complete the context manager before emitting further commands.
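In this case the logs stop at the ROLLBACK emitted by session.rollback(); that statement ends the test's outer transaction, so the later commit has nothing left to work with: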
2024-08-04 22:51:49,974 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-08-04 22:51:49,975 INFO sqlalchemy.engine.Engine SELECT orders.id, orders.state
FROM orders
2024-08-04 22:51:49,975 INFO sqlalchemy.engine.Engine [generated in 0.00015s] {}
2024-08-04 22:51:49,976 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
2024-08-04 22:51:49,976 INFO sqlalchemy.engine.Engine [generated in 0.00008s] {'state': 'error', 'orders_id': 1}
2024-08-04 22:51:49,977 INFO sqlalchemy.engine.Engine ROLLBACK
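A minimal replay with Python's built-in sqlite3 module (standing in for PostgreSQL, an assumption so it runs without a server) shows what that final ROLLBACK means at the SQL level: the outer transaction is gone, so later writes would no longer be covered by the test's final rollback. The InvalidRequestError above is SQLAlchemy refusing to keep using the connection once the context-managed transaction has been closed; raw SQL has no such guard.

```python
import sqlite3

# sqlite3 stands in for PostgreSQL (assumption); isolation_level=None makes
# every BEGIN/ROLLBACK below explicit.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, state TEXT NOT NULL)")
conn.execute("INSERT INTO orders VALUES (1, 'start')")

conn.execute("BEGIN")  # connection.begin() in our_session_maker
conn.execute("UPDATE orders SET state = 'error' WHERE id = 1")  # session.flush()

# Under "rollback_only", session.rollback() reaches the connection as a plain
# ROLLBACK, which ends the only transaction the connection has.
conn.execute("ROLLBACK")
outer_still_open = conn.in_transaction

# Nothing stops raw SQL here: the next write runs outside any transaction
# and is kept, and the test's final rollback has nothing left to undo.
conn.execute("UPDATE orders SET state = 'fixed' WHERE id = 1")
try:
    conn.execute("ROLLBACK")
    final_rollback_failed = False
except sqlite3.OperationalError:  # SQLite reports that no transaction is active
    final_rollback_failed = True

leaked_state = conn.execute("SELECT state FROM orders WHERE id = 1").fetchone()[0]
print(outer_still_open, final_rollback_failed, leaked_state)  # False True fixed
```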
If our_session_maker passes join_transaction_mode='create_savepoint', the test script runs as expected, i.e.:
@contextlib.contextmanager
def our_session_maker(engine):
    with engine.connect() as connection:
        with connection.begin() as transaction:
            with Session(bind=connection, join_transaction_mode='create_savepoint') as s:
                yield s
            transaction.rollback()
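Looking at the logs (while using 'create_savepoint'), you can see that the session's rollback and commit only go to savepoints, and the outer transaction is rolled back at the very end: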
2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine SAVEPOINT sa_savepoint_1
2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine [no key 0.00006s] {}
2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine SELECT orders.id, orders.state
FROM orders
2024-08-04 22:48:12,020 INFO sqlalchemy.engine.Engine [generated in 0.00007s] {}
2024-08-04 22:48:12,021 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
2024-08-04 22:48:12,021 INFO sqlalchemy.engine.Engine [generated in 0.00008s] {'state': 'error', 'orders_id': 1}
2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine ROLLBACK TO SAVEPOINT sa_savepoint_1
2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine [no key 0.00007s] {}
2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine SAVEPOINT sa_savepoint_2
2024-08-04 22:48:12,022 INFO sqlalchemy.engine.Engine [no key 0.00006s] {}
2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine SELECT orders.id AS orders_id
FROM orders
WHERE orders.id = %(pk_1)s
2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine [generated in 0.00007s] {'pk_1': 1}
2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine UPDATE orders SET state=%(state)s WHERE orders.id = %(orders_id)s
2024-08-04 22:48:12,023 INFO sqlalchemy.engine.Engine [cached since 0.002251s ago] {'state': 'fixed', 'orders_id': 1}
2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine RELEASE SAVEPOINT sa_savepoint_2
2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine [no key 0.00005s] {}
2024-08-04 22:48:12,024 INFO sqlalchemy.engine.Engine ROLLBACK
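Replaying the transaction statements from that log with Python's built-in sqlite3 module (standing in for PostgreSQL, an assumption so the sketch runs without a server) makes the sequence easier to follow: the session's rollback and commit only touch savepoints, the outer transaction stays open throughout, and the final ROLLBACK discards the 'fixed' state as well.

```python
import sqlite3

# sqlite3 stands in for PostgreSQL (assumption); isolation_level=None makes
# every transaction statement explicit, mirroring the log.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, state TEXT NOT NULL)")
conn.execute("INSERT INTO orders VALUES (1, 'start')")


def current_state():
    return conn.execute("SELECT state FROM orders WHERE id = 1").fetchone()[0]


# Same transaction statements, in the same order, as the log above
# (the ORM's own SELECTs are left out).
conn.execute("BEGIN")                                  # connection.begin()
conn.execute("SAVEPOINT sa_savepoint_1")               # session starts work
conn.execute("UPDATE orders SET state = 'error' WHERE id = 1")  # flush()
conn.execute("ROLLBACK TO SAVEPOINT sa_savepoint_1")   # session.rollback()
after_rollback = current_state()

conn.execute("SAVEPOINT sa_savepoint_2")               # session keeps working
conn.execute("UPDATE orders SET state = 'fixed' WHERE id = 1")
conn.execute("RELEASE SAVEPOINT sa_savepoint_2")       # session.commit()
after_commit = current_state()
outer_still_open = conn.in_transaction

conn.execute("ROLLBACK")                               # transaction.rollback()
after_test = current_state()
print(after_rollback, after_commit, outer_still_open, after_test)
# start fixed True start
```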
It would be easy to use the default, "conditional_savepoint", and not notice, because rollback is rarer than commit. My original fabricated example actually works fine with the default, but it fails if you use "control_fully". For deterministic behavior during tests, it seems you would want to use "create_savepoint" if it is available to you; otherwise, code that would work correctly in production may not work in your tests.
There is some context in the "What's New in SQLAlchemy 2.0?" section of the documentation, under new-transaction-join-modes-for-session.
Also, as mentioned above, there is an example in the documentation at joining-a-session-into-an-external-transaction-such-as-for-test-suites.