
How to mock a function imported directly by the tested module without knowing the module name in python


Assume I have a function defined in a module:

module_a.py

def foo():
    return 10

And I want to create an API to patch the function:

patcher.py

from unittest import mock  # Python 3; the standalone "mock" package has the same API

class Patcher(object):

    def __enter__(self):
        self.patcher = mock.patch('module_a.foo',
                                  mock.Mock(return_value=15))

        self.patcher.start()

    def __exit__(self, *args):
        self.patcher.stop()

The problem is that I don't know the name of the module that will use my API. A test like this:

test1.py

from patcher import Patcher
import module_a

with Patcher():
    assert module_a.foo() == 15

will work. But a test written like this:

test2.py

from patcher import Patcher
from module_a import foo

with Patcher():
    assert foo() == 15

will fail.

Is there any way to avoid forcing the API user to write their tests and modules(!) like the first option?
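A minimal sketch of why test1 passes and test2 fails, assuming in-memory stand-ins for module_a.py, test1.py and test2.py (the helper make_module is hypothetical, added only so the example is self-contained). mock.patch('module_a.foo', ...) replaces the attribute on the module object. test1 looks foo up on that module at call time, while "from module_a import foo" in test2 bound its own name to the original function at import time. The standard advice is to patch where the name is looked up (e.g. 'test2.foo'), which is exactly what requires knowing the caller's module name.

```python
import sys
import types
from unittest import mock


def make_module(name, source):
    """Hypothetical helper: build an in-memory module and register it in sys.modules."""
    module = types.ModuleType(name)
    sys.modules[name] = module  # register first so imports inside `source` resolve
    exec(source, module.__dict__)
    return module


# Stand-ins for module_a.py, test1.py and test2.py from the question.
module_a = make_module("module_a", "def foo():\n    return 10\n")
test1 = make_module("test1", "import module_a\ndef run():\n    return module_a.foo()\n")
test2 = make_module("test2", "from module_a import foo\ndef run():\n    return foo()\n")

with mock.patch("module_a.foo", mock.Mock(return_value=15)):
    # test1 reads module_a.foo at call time, so it sees the mock.
    # test2 holds its own reference bound at import time, so it sees the original.
    results_patch_module_a = (test1.run(), test2.run())

with mock.patch("test2.foo", mock.Mock(return_value=15)):
    # Patching the name where it is looked up works, but the caller's module name is needed.
    results_patch_test2 = test2.run()

print(results_patch_module_a)  # (15, 10)
print(results_patch_test2)     # 15
```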


Solution

  • There is a way to "patch" a function without knowing where it is used. That was the requirement in my question, since the patcher is part of my library's API and I don't want users to have to pass the path of every test module that uses my library.

    The solution I found was to go over all loaded modules, look for foo in each of them, and replace it, essentially implementing patch myself. For imports that happen only after the Patcher has started, I import module_a myself and replace foo on it as well, so a later "from module_a import foo" picks up the fake.

    Now the code will look like this:

    patcher.py

    import sys
    from unittest import mock  # the standalone "mock" package has the same API

    import module_a
    from module_a import foo as _orig_foo


    class Patcher(object):

        def __init__(self):
            self.undo_set = set()
            self.fake_foo = mock.Mock(return_value=15)

        def __enter__(self):
            # Skip module_a (patched directly below) and this module, which must
            # keep the real function in _orig_foo so it can be restored later.
            modules = [
                module for module in list(sys.modules.values())
                if module is not None and hasattr(module, '__name__') and
                module.__name__ not in ('module_a', __name__)
            ]

            for module in modules:
                for attr in dir(module):
                    try:
                        attribute_value = getattr(module, attr)
                    except (ValueError, AttributeError, ImportError):
                        # Some libraries raise on attribute access; skip those.
                        continue

                    # Identity check: replace only references to the real foo.
                    if attribute_value is _orig_foo:
                        setattr(module, attr, self.fake_foo)
                        self.undo_set.add((module, attr, attribute_value))

            # Solve for future imports: a later "from module_a import foo" gets the fake.
            module_a.foo = self.fake_foo
            # Returning the mock allows "with Patcher() as fake_foo:".
            return self.fake_foo

        def __exit__(self, *args):
            module_a.foo = _orig_foo
            for mod, attr, val in self.undo_set:
                setattr(mod, attr, val)
            self.undo_set = set()
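The same identity-based scan can be sketched as a generic context manager that takes the original object and its replacement, so it is not tied to module_a. The name patch_everywhere and the in-memory module_a/test2 stand-ins are hypothetical. Two design choices differ from the Patcher above and are assumptions rather than requirements: it reads vars(module) instead of dir() plus getattr(), which avoids triggering lazy module attributes, and it skips sys.modules entries that are not real module objects. Because the original module's own foo is also a module-level reference to the same object, it gets replaced too, which covers later "from module_a import foo" imports without a special case. Like the Patcher, it only sees module-level names: references held in containers, closures or default arguments are not replaced, and a name imported while the patch is active keeps pointing at the fake after exit.

```python
import contextlib
import sys
import types
from unittest import mock


@contextlib.contextmanager
def patch_everywhere(original, replacement):
    """Hypothetical helper: swap every module-level reference to `original`."""
    undo = []
    # Snapshot sys.modules so imports triggered elsewhere cannot change it mid-loop.
    for module in list(sys.modules.values()):
        if not isinstance(module, types.ModuleType):
            continue  # some libraries put non-module objects in sys.modules
        for attr, value in list(vars(module).items()):
            if value is original:  # identity, not equality
                setattr(module, attr, replacement)
                undo.append((module, attr))
    try:
        yield replacement
    finally:
        # Restore in reverse order, even if the body raised.
        for module, attr in reversed(undo):
            setattr(module, attr, original)


# Usage with in-memory stand-ins for module_a.py and test2.py.
module_a = types.ModuleType("module_a")
sys.modules["module_a"] = module_a
exec("def foo():\n    return 10\n", module_a.__dict__)
test2 = types.ModuleType("test2")
sys.modules["test2"] = test2
exec("from module_a import foo\ndef run():\n    return foo()\n", test2.__dict__)

with patch_everywhere(module_a.foo, mock.Mock(return_value=15)) as fake:
    during = test2.run()                   # test2's own "foo" name was replaced
    from module_a import foo as late_foo   # imported after the patch started
    late = late_foo()
after = test2.run()                        # restored on exit

print(during, late, after, fake.call_count)  # 15 15 10 2
```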