Search code examples
pythonunit-testingpython-asynciopython-unittest

Unit test async method python


Trying to learn about unit testing for async method

Let's say I have this that I want to test

async def run(self):
    while True:
        async with DBHook(mssql_conn_id=self.db_conn).get_cursor() as cursor:
            cursor.execute(self.sql)
            rows = cursor.fetchone()
            if rows[0] > 0:
                yield TriggerEvent(True)
            await asyncio.sleep(self.sleep_interval)

This is in a class of course.

Now, I would like to assert that

asyncio is called when number of rows affected by fetchone is zero

So, I am trying to write to test like this (tried other variations as well)

import aiounittest
import mock


import mymodule as st


class SQLTriggerTests(aiounittest.AsyncTestCase):

    @mock.patch("mymodule.DBHook")
    @mock.patch("mymodule.asyncio")
    async def test_run(self, mock_asyncio, mock_MsSqlIntegratedHook):
        my_obj = mymodule.classname(sql="select blabla", mssql_conn_id="db")  
        conn=mock_MsSqlIntegratedHook.return_value.get_cursor.return_value.__enter__.return_value
        conn.fetchone.return_value= (0,)
        res = my_obj.run()
        mock_asyncio.sleep.assert_called()

The assertion fails saying asyncio is not called.

I also tried using await in the call to my_obj.run(), but that gives me an error saying

can't be awaited on a generator object.

How should I test this properly. Python version:3.9

Update 1: since last time, I have made some progress but still unable to successfully test the async method. I now fixed the await problem with the generator by awaiting on the anext as the async generator can't be awaited on, but now I can't mock cursor execute or fetchone. I have tried mocking the coroutine like the following (following some S.O posts).

Basically now I can't seem to be able to mock rows = cursor.fetchone() and therefore the test fails when it encounters if rows[0] > 0

def get_mock_coro(return_value):
    m = mock.MagicMock()
    @asyncio.coroutine
    def mock_coro(*args, **kwargs):
        return m(*args, **kwargs)
    
    mock_coro.mock = m
    mock_coro.execute = mock.MagicMock()
    mock_coro.fetchone = mock.MagicMock()
    mock_coro.fetchone.return_value = return_value

    return Mock(wraps=mock_coro)

@mock.patch("plugin.triggers.sql_trigger.DBHook")
@mock.patch("plugin.triggers.sql_trigger.asyncio")
async def test_run(self, mock_asyncio, mock_MsSqlIntegratedHook):
    sql_trigger_test_obj = st.SQLTrigger(sql="select blabla", mssql_conn_id="conn")  
    mock_record = (0)
    mock_MsSqlIntegratedHook().get_cursor.return_value.__aenter__.return_value=get_mock_coro(mock_record)
    gen = sql_trigger_test_obj.run()
    res=await (gen.__anext__())
    mock_asyncio.sleep.assert_called()


Solution

  • Ok, so I found a way to test this. This was through debugging and also following some bits and pieces of information from different posts on S.O

    1. First I could not await an async generator, so I had to first do ___anext__() and then await.
    2. Then I had trouble mocking the fetchone and execute on the cursor, as they were coroutines.

    This is how I finally fixed it. This works, however if there is a better way to do this , then I am all eyes and ears.

    import asyncio
    from unittest.mock import Mock
    
    import aiounittest
    import mock
    
    import mymodule as st
    
    
    def get_mock_asyncio():
        m = mock.MagicMock()
        @asyncio.coroutine
        def mock_coro(*args, **kwargs):
            return m(*args, **kwargs)
        
        mock_coro.mock = m
        mock_coro.sleep = mock.MagicMock()
    
        return Mock(wraps=mock_coro)
    
    def get_mock_sql_corountine(return_value):
        m = mock.MagicMock()
        @asyncio.coroutine
        def mock_coro(*args, **kwargs):
            return m(*args, **kwargs)
        
        mock_coro.mock = m
        mock_coro.execute = mock.MagicMock()
        mock_coro.fetchone = mock.MagicMock(return_value=return_value)
    
        return Mock(wraps=mock_coro)
    
    class SQLTriggerTests(aiounittest.AsyncTestCase):
        
        @mock.patch("mymodule.DBHook", autospec=True)
        @mock.patch("mymodule.asyncio", autospec=True)
        async def test_run(self, mock_asyncio, mock_MsSqlIntegratedHook):
            sql_trigger_test_obj = st.SQLTrigger(sql="select blabla", mssql_conn_id="dbconn")  
            mock_record = (1,)
            mock_MsSqlIntegratedHook.return_value.get_cursor.return_value.__aenter__.return_value = get_mock_sql_corountine(mock_record)
            mock_asyncio.sleep.return_value = get_mock_asyncio()
            gen = sql_trigger_test_obj.run()
            await (gen.__anext__())
            mock_asyncio.sleep.assert_not_called()