Given the implementation below, I am trying to test it using an asynchronous session. My attempt is as follows:
models.py
from typing import Optional, Union

from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
class Paginator:
    """Paginate the rows of a raw SQL ``SELECT`` statement.

    The query text is embedded as a derived table (sub-select) when counting,
    so it must be a single ``SELECT`` statement. Bind-parameter values
    referenced in it are supplied separately through ``params``.
    """

    def __init__(
        self,
        conn: Union[Connection, AsyncConnection, AsyncSession],
        query: str,
        params: Optional[dict] = None,
        batch_size: int = 10,
    ):
        """Store the paging configuration.

        Args:
            conn: Connection or session used to execute the SQL. The async
                helpers ``await conn.execute(...)``, so they require an
                ``AsyncConnection`` or ``AsyncSession``.
            query: Raw SQL ``SELECT`` statement to paginate. It is trusted
                text interpolated into generated SQL; pass user-supplied
                values through ``params`` instead.
            params: Bind-parameter values for ``query``; ``None`` means none.
            batch_size: Batch size for paging; stored as-is.
        """
        self.conn = conn
        self.query = query
        self.params = params
        self.batch_size = batch_size
        self.current_offset = 0
        # Not computed in __init__; stays None until something sets it.
        self.total_count = None

    async def _get_total_count_async(self) -> int:
        """Fetch the total count of records asynchronously.

        Returns:
            The number of rows ``self.query`` produces, obtained with
            ``COUNT(*)`` over the query wrapped as a derived table.
        """
        # Several databases (e.g. MySQL, PostgreSQL before 16) require an
        # alias on a derived table, hence "as total".
        count_query = f"SELECT COUNT(*) FROM ({self.query}) as total"
        # Values are bound rather than interpolated, so params cannot inject SQL.
        stmt = text(count_query).bindparams(**(self.params or {}))
        result = await self.conn.execute(stmt)
        return result.scalar()
test_models.py
@pytest.fixture(scope='function')
async def async_session():
    """Yield an ``AsyncSession`` in a transaction that is rolled back afterwards.

    NOTE: this is a plain ``@pytest.fixture`` on an ``async def`` generator.
    Unless pytest-asyncio runs in auto mode, pytest does not drive it, and the
    test receives the async generator object itself. That is what produces
    ``'async_generator' object has no attribute 'execute'``.
    """
    async_engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    session_factory = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )
    try:
        async with session_factory() as session:
            await session.begin()
            yield session
            # Undo whatever the test wrote so every test starts from the same state.
            await session.rollback()
    finally:
        # A new engine is created per invocation; dispose it so its pooled
        # connections are closed instead of leaking across tests.
        await async_engine.dispose()
@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Counting ``test_table`` through ``Paginator`` should return 0."""
    # Prepare the paginator with the object injected by the ``async_session``
    # fixture (the original referenced an undefined name ``session``).
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2,
    )
    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()
    # Assertion to verify the result (assumes test_table starts empty).
    assert total_count == 0
When I run `pytest`, I get the following error: `AttributeError: 'async_generator' object has no attribute 'execute'`. I am pretty sure there is an easy way to fix this, but I am unaware of it.
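You should pass an object with an awaitable `execute()`, such as an `AsyncConnection` or an `AsyncSession`, to the `Paginator` class. Instead, you're passing the `async_session` fixture itself directly. Because it is an `async def` generator decorated with a plain `@pytest.fixture`, your test receives the async generator object, not the `AsyncSession` it yields.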
To solve the issue there are two possible approaches:

1. Iterate over the `async_session` generator within the test function to reach the `AsyncSession` it yields:

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Drive the ``async_session`` generator by hand and count ``test_table``.

    With a plain ``@pytest.fixture`` the test receives the async generator
    itself. ``async for`` runs it up to its ``yield``, producing the
    ``AsyncSession``. Once the loop body finishes, the loop resumes the
    generator, so the code after its ``yield`` (the rollback) executes.
    """
    # The yielded object is an AsyncSession, not an AsyncConnection.
    async for session in async_session:
        paginator = Paginator(
            conn=session,
            query="SELECT * FROM test_table",
            batch_size=2,
        )
        total_count = await paginator._get_total_count_async()
        # Assumes test_table starts empty.
        assert total_count == 0
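2. Use the fixture decorator from the `pytest_asyncio` PyPI package, so that pytest-asyncio runs the fixture and injects the yielded `AsyncSession` into the test:

@pytest_asyncio.fixture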
async def async_session():
    """Yield an ``AsyncSession`` in a transaction that is rolled back afterwards.

    Meant to be registered with ``@pytest_asyncio.fixture`` so that
    pytest-asyncio awaits the generator and the test receives the session.
    """
    async_engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    session_factory = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )
    try:
        async with session_factory() as session:
            await session.begin()
            yield session
            # Undo whatever the test wrote so every test starts from the same state.
            await session.rollback()
    finally:
        # A new engine is created per invocation; dispose it so its pooled
        # connections are closed instead of leaking across tests.
        await async_engine.dispose()
@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Counting ``test_table`` through ``Paginator`` should return 0."""
    # Prepare the paginator with the AsyncSession injected by the fixture
    # (the original referenced an undefined name ``session``).
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2,
    )
    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()
    # Assertion to verify the result (assumes test_table starts empty).
    assert total_count == 0
Here's a post regarding this issue.