
How do I mock asyncio database context managers?


I've been struggling for a bit to mock the typical async database connection setup:

async with aiomysql.create_pool(...) as pool:
    async with pool.acquire() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute("BEGIN")
            ...

My first attempt at a test function looked roughly like this:

async def test_database(mocker: pytest_mock.MockerFixture):
    context = mocker.AsyncMock()
    pool = mocker.AsyncMock()
    connection = mocker.AsyncMock()
    cursor = mocker.AsyncMock()
    cursor.fetchall.return_value = [{'Database': 'information_schema'}]
    cursor.fetchone.return_value = {'COUNT(*)': 0}
    cursor.rowcount = 0
    connection.cursor.return_value.__aenter__.return_value = cursor
    pool.acquire.return_value.__aenter__.return_value = connection
    context.__aenter__.return_value = pool
    mocker.patch('aiomysql.create_pool', return_value=context)

    async with aiomysql.create_pool() as p:
        async with p.acquire() as c:
            async with c.cursor() as cur:
                await cur.execute("BEGIN")

If you've been getting AttributeError: __aenter__ (or, on Python 3.11 and later, a TypeError about the asynchronous context manager protocol), this post is for you.
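The failure can be reproduced without aiomysql or pytest at all. The following is a minimal sketch using only unittest.mock; the function name reproduce is made up for illustration, and pool stands in for whatever the outer context manager yields.

```python
import asyncio
from unittest.mock import AsyncMock


async def reproduce():
    pool = AsyncMock()
    # Attributes of an AsyncMock are AsyncMocks themselves, so calling
    # acquire() without await produces a coroutine, not a context manager.
    pending = pool.acquire()
    was_coroutine = asyncio.iscoroutine(pending)
    try:
        async with pending:
            pass
    except (AttributeError, TypeError) as exc:
        # AttributeError on older Pythons, TypeError on 3.11 and later.
        error_name = type(exc).__name__
    else:
        error_name = None
    finally:
        pending.close()  # silence the "coroutine was never awaited" warning
    return was_coroutine, error_name


if __name__ == "__main__":
    print(asyncio.run(reproduce()))
```

The coroutine object has no __aenter__, which is exactly what the error message complains about.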


Solution

  • The key point is that there is no await between async with and the function call: create_pool, acquire, and cursor are called synchronously and return async context managers. In the test function above, pool and connection are AsyncMock objects, so their methods are AsyncMocks as well, and calling acquire() or cursor() returns a coroutine that would have to be awaited before it yields the next prepared mock. Instead, we want acquire() and cursor() to return immediately. The solution is to mix Mock/MagicMock and AsyncMock.

    import aiomysql
    import pytest_mock


    async def test_database(mocker: pytest_mock.MockerFixture):
        context = mocker.AsyncMock()
        pool = mocker.Mock()
        connection = mocker.Mock()
        cursor = mocker.AsyncMock()
        cursor.fetchall.return_value = [{'Database': 'information_schema'}]
        cursor.fetchone.return_value = {'COUNT(*)': 0}
        cursor.rowcount = 0
        connection.cursor.return_value = mocker.AsyncMock()
        connection.cursor.return_value.__aenter__.return_value = cursor
        pool.acquire.return_value = mocker.AsyncMock()
        pool.acquire.return_value.__aenter__.return_value = connection
        context.__aenter__.return_value = pool
        mocker.patch('aiomysql.create_pool', return_value=context)
    
        # calls create_pool synchronously and gets 'context',
        # which is an AsyncMock and facilitates the __aenter__ call,
        # which returns 'pool' as a regular Mock
        async with aiomysql.create_pool() as p:
            # calls 'acquire()' synchronously and gets an anonymous AsyncMock,
            # which facilitates the __aenter__ call,
            # which returns 'connection' as a regular Mock
            async with p.acquire() as c:
                # again, 'cursor()' synchronously, get AsyncMock,
                # __aenter__ to get 'cursor'
                async with c.cursor() as cur:
                    # continue regular operations on AsyncMock object
                    await cur.execute("BEGIN")
    

    Note (aiomysql-specific): if your code calls connection.begin() and the like, add connection.begin = mocker.Mock() so that the call returns synchronously. If the code awaits the call instead, use mocker.AsyncMock() for that attribute.
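The same layering can be tried out with the standard library alone, which helps when checking the mock wiring in isolation. This is a sketch under assumptions, not the only way to do it: async_cm and make_pool_context are hypothetical helper names, a plain Mock stands in for the patched aiomysql.create_pool, and MagicMock is used for the context-manager objects (it supports __aenter__ since Python 3.8) where the test above uses AsyncMock.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock


def async_cm(value):
    """Hypothetical helper: an object usable in `async with` that yields `value`."""
    cm = MagicMock()
    cm.__aenter__.return_value = value
    return cm


def make_pool_context():
    # The cursor is awaited on (execute, fetchone), so it is an AsyncMock.
    cursor = AsyncMock()
    cursor.fetchone.return_value = {'COUNT(*)': 0}
    # connection and pool are plain Mocks: cursor() and acquire() are
    # called synchronously and must return a context manager right away.
    connection = Mock()
    connection.cursor.return_value = async_cm(cursor)
    pool = Mock()
    pool.acquire.return_value = async_cm(connection)
    return async_cm(pool), cursor


async def demo():
    context, cursor = make_pool_context()
    # Stand-in for mocker.patch('aiomysql.create_pool', return_value=context).
    create_pool = Mock(return_value=context)
    async with create_pool() as p:
        async with p.acquire() as c:
            async with c.cursor() as cur:
                await cur.execute("BEGIN")
                row = await cur.fetchone()
    cursor.execute.assert_awaited_once_with("BEGIN")
    return row


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Wrapping the __aenter__ setup in one helper keeps each level to a single line and makes it harder to accidentally use an AsyncMock where a synchronous call is expected.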