Tags: python, unit-testing, mocking

Python unit-testing - How to patch an async call internal to the method I am testing


I'm using unittest.mock to build tests for my Python code. The method I am trying to test contains an async call to another function. I want to patch that async call so that the mock just returns a test value for the asset id, instead of actually calling the async function. I have tried many things I've found online, but none have worked so far.

Simplified example below:

test.py

import pytest

from app.create.creations import generate_new_asset
from app.fakeapi.utils import create_asset

from unittest.mock import Mock, patch

@patch("app.fakeapi.utils.create_asset")
@pytest.mark.anyio
async def test_generate_new_asset(mock_create):
    mock_create.return_value = 12345678

    await generate_new_asset()

    ...

creations.py

from app.fakeapi.utils import create_asset
...

async def generate_new_asset():
    ...
    # When I run tests this does not return the 12345678 value, but actually calls the `create_asset` method.
    return await create_asset(...) 

Solution

  • Testing async code is a bit tricky. If you are using Python 3.8 or higher, `AsyncMock` is available in `unittest.mock`.

    Note: this works only on Python 3.8 and later.

    I think in your case the event loop is missing. Here is code that should work, though you may need a few tweaks. You will also need to install pytest-mock (it provides the `mocker` fixture) and pytest-asyncio (it provides `@pytest.mark.asyncio`). Having the mock as a fixture lets you return different values for different test scenarios.

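    import asyncio
    from unittest.mock import AsyncMock

    import pytest

    from app.create.creations import generate_new_asset

    @pytest.fixture  # function scope: `mocker` is function-scoped, so scope="module" would fail
    def mock_create_asset(mocker):
        # Patch the name where generate_new_asset looks it up. creations.py did
        # `from app.fakeapi.utils import create_asset`, so patching app.fakeapi.utils
        # would leave creations.create_asset pointing at the real function.
        return mocker.patch("app.create.creations.create_asset", new_callable=AsyncMock)

    # Only needed on older pytest-asyncio versions; recent releases manage the loop
    # themselves and deprecate overriding this fixture.
    @pytest.fixture(scope="module")
    def event_loop():
        loop = asyncio.new_event_loop()
        yield loop
        loop.close()

    @pytest.mark.asyncio
    async def test_generate_new_asset(mock_create_asset):
        mock_create_asset.return_value = 12345678
        result = await generate_new_asset()
        assert result == 12345678
        mock_create_asset.assert_awaited_once()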
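If the pytest plugins or event-loop setup get in the way, a standard-library alternative is `unittest.IsolatedAsyncioTestCase` (Python 3.8+). This is a different technique from the pytest fixture, swapped in here for illustration. It runs each async test in its own event loop, and `patch` can decorate `async def` test methods directly. The module `fake_creations` below is a hypothetical in-memory stand-in for `app.create.creations`. With the real project you would patch `"app.create.creations.create_asset"` instead.

```python
import sys
import types
import unittest
from unittest.mock import AsyncMock, patch

# Hypothetical in-memory stand-in for app/create/creations.py so the sketch runs alone.
creations = types.ModuleType("fake_creations")

async def _real_create_asset():
    raise RuntimeError("the real create_asset should not run in tests")

creations.create_asset = _real_create_asset

async def generate_new_asset():
    # Looks the name up in the creations namespace, as code in creations.py would.
    return await creations.create_asset()

creations.generate_new_asset = generate_new_asset
sys.modules["fake_creations"] = creations


class GenerateNewAssetTests(unittest.IsolatedAsyncioTestCase):
    # Each async test runs in a fresh event loop, so no event_loop fixture is needed.

    @patch("fake_creations.create_asset", new_callable=AsyncMock)
    async def test_returns_mocked_id(self, mock_create):
        mock_create.return_value = 12345678
        self.assertEqual(await generate_new_asset(), 12345678)
        mock_create.assert_awaited_once()

    @patch("fake_creations.create_asset", new_callable=AsyncMock)
    async def test_propagates_errors(self, mock_create):
        # A different scenario: make the mocked call fail when awaited.
        mock_create.side_effect = ValueError("API down")
        with self.assertRaises(ValueError):
            await generate_new_asset()
```

Saved as a test module, this can be run with `python -m unittest`. Each test configures the mock differently, which covers the "different values for different scenarios" use case without a shared fixture.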