
Pytest + mock: patch does not work without with clause


I'm testing complex logic that requires joining a central fact table with 10-20 smaller dimension tables. I want to mock those 10-20 smaller tables.

How can I patch the methods' return values in a for loop? See the code below.

tables.py:

class BaseClass(object):

    def load(self, path: str):
        ...

class SmallTable1(BaseClass):

    def load(self):
        return super().load(self.path)

# ... SmallTable2 to SmallTable19 follow the same pattern ...

class SmallTable20(BaseClass):

    def load(self):
        return super().load(self.path)

test_logic.py:

# QUESTION - how to make it work
def test_not_work(datasets):
   for i in range(1, 21):
      table = 'SmallTable' + str(i)
      with mock.patch('some_package.tables.{}.load'.format(table)) as load_mock:
         load_mock.return_value = datasets[table]
      do_joins()  # here my mocks don't work


def test_works(datasets):
   with mock.patch('some_package.tables.SmallTable1.load') as load_mock_1:
       load_mock_1.return_value = datasets['SmallTable1']

       with mock.patch('some_package.tables.SmallTable2.load') as load_mock_2:
           load_mock_2.return_value = datasets['SmallTable2']

           # ... repeat the same 8-18 times more, one level deeper each time

           do_joins()  # here my mocks do work, but the code is cumbersome and drifts far to the right
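The difference between the two tests comes down to scope: `mock.patch` used as a context manager restores the original attribute as soon as its `with` block exits. In `test_not_work`, `do_joins()` runs after each `with` block has already closed, so no patch is active at that point. A minimal sketch with a hypothetical stand-in class:

```python
from unittest import mock


class SmallTable1:  # hypothetical stand-in for one of the table classes
    def load(self):
        return 'real data'


with mock.patch.object(SmallTable1, 'load', return_value='fake data'):
    inside = SmallTable1().load()   # the patch is active inside the block

outside = SmallTable1().load()      # the patch was undone when the block exited

print(inside, outside)  # fake data real data
```

The nested version works only because `do_joins()` is called while every enclosing `with` block is still open.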

P.S. Alternatively, I could mock BaseClass.load, but then I don't know how to return a different dataset for each table (class).
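Patching `BaseClass.load` once can still return a different dataset per table. With `autospec=True`, `unittest.mock` replaces the method with a function that receives the instance as its first argument, so a `side_effect` can dispatch on the subclass name. A minimal sketch, assuming each subclass's `load()` calls `super().load(self.path)` as in `tables.py`; the `path` values and the contents of `datasets` are hypothetical:

```python
from unittest import mock


class BaseClass:
    def load(self, path):
        raise RuntimeError('real load should not run in tests')


class SmallTable1(BaseClass):
    path = 'table1.csv'  # hypothetical path

    def load(self):
        return super().load(self.path)


class SmallTable2(BaseClass):
    path = 'table2.csv'  # hypothetical path

    def load(self):
        return super().load(self.path)


datasets = {'SmallTable1': 'data-1', 'SmallTable2': 'data-2'}


def fake_load(self, path):
    # autospec=True passes the calling instance as `self`,
    # so we can pick the dataset by the subclass's name
    return datasets[type(self).__name__]


with mock.patch.object(BaseClass, 'load', autospec=True, side_effect=fake_load):
    results = [SmallTable1().load(), SmallTable2().load()]

print(results)  # ['data-1', 'data-2']
```

Without `autospec=True` the replacement is a plain `MagicMock`, which is not bound to the instance, so the side effect would receive only `path` and could not tell the tables apart.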


Solution

  • Under the assumption that do_joins() should be called once, after the loop, with all tables patched, you could write a fixture that uses contextlib.ExitStack to set up all the mocks:

    from contextlib import ExitStack
    from unittest import mock
    import pytest
    
    from some_package import do_joins
    
    
    @pytest.fixture
    def datasets():
        ...
    
    @pytest.fixture
    def mock_tables(datasets):
        # Every patch is entered on the same ExitStack, so none of them
        # is undone until the stack itself is closed.
        with ExitStack() as stack:
            for i in range(1, 21):
                table = 'SmallTable' + str(i)
                load_mock = stack.enter_context(
                    mock.patch('some_package.tables.{}.load'.format(table)))
                load_mock.return_value = datasets[table]
            # The test body runs here, while all 20 patches are active.
            yield
    
    
    def test_join(mock_tables):
        do_joins()
    

    Because the yield happens inside the with block, all mocks are still active while the test runs; ExitStack removes them only after the test has finished.

    If you have pytest-mock installed, you can use its mocker fixture instead; it undoes all patches automatically when the test finishes:

    @pytest.fixture
    def mock_tables(datasets, mocker):
        for i in range(1, 21):
            table = 'SmallTable' + str(i)
            load_mock = mocker.patch('some_package.tables.{}.load'.format(table))
            load_mock.return_value = datasets[table]
        yield
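For a quick check outside pytest, here is a self-contained sketch of the same ExitStack pattern. It assumes three hypothetical table classes defined in place instead of `some_package.tables`, and uses `contextlib.contextmanager` in place of `@pytest.fixture`; the mechanics (all patches entered on one stack, all undone when it closes) are the same as in the fixture.

```python
from contextlib import ExitStack, contextmanager
from unittest import mock


# Hypothetical stand-ins for the real table classes
class SmallTable1:
    def load(self):
        return 'real 1'


class SmallTable2:
    def load(self):
        return 'real 2'


class SmallTable3:
    def load(self):
        return 'real 3'


TABLES = [SmallTable1, SmallTable2, SmallTable3]
DATASETS = {'SmallTable1': 'fake 1', 'SmallTable2': 'fake 2', 'SmallTable3': 'fake 3'}


@contextmanager
def mock_tables(datasets):
    # Enter one patch per table on a single ExitStack; they all stay
    # active until the stack closes when this generator resumes.
    with ExitStack() as stack:
        for cls in TABLES:
            stack.enter_context(
                mock.patch.object(cls, 'load', return_value=datasets[cls.__name__]))
        yield


with mock_tables(DATASETS):
    patched = [cls().load() for cls in TABLES]   # every table is mocked here
restored = [cls().load() for cls in TABLES]      # every patch has been undone

print(patched)   # ['fake 1', 'fake 2', 'fake 3']
print(restored)  # ['real 1', 'real 2', 'real 3']
```

The loop replaces the hand-written nesting from the question: ExitStack keeps the contexts open the same way nested `with` blocks would, and unwinds them in reverse order on exit.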