python, mocking, optional-parameters, monkeypatching, functools

Python: is mock.patch.object with functools.partial bound arguments possible?


How can I patch an object's method with a function that has a different signature (e.g. an additional argument)? I've tried to bind the optional argument with functools.partial, but this does not seem to work. I cannot use plain monkey patching here, since the patched class is used in a location where I cannot otherwise patch it that way.

Any help appreciated.

import mock
import functools

# this class lives in another (unchangeable) module, __len__ method has to be patched
class ToOverride(object):
    def __len__(self):
        raise NotImplementedError()

# this code is changeable
def my_len(self, arg):
    return arg+1

my_len_bound = functools.partial(my_len, arg=1)

with mock.patch.object(ToOverride, '__len__', my_len_bound):
    inst = ToOverride()
    print len(inst) # expected output: 2

When calling len(inst) inside the mock.patch.object context, I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-7-bfdb41d8628f> in <module>()
      1 with mock.patch.object(ToOverride, '__len__', my_len_bound):
      2     inst = ToOverride()
----> 3     print len(inst)

TypeError: my_len() takes exactly 2 arguments (1 given)

However, invoking my_len with None as the first argument works as expected (it prints 2).
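To see what the partial actually does, here is a small Python 3 sketch (the snippet above uses Python 2 syntax). The partial pre-fills only the keyword argument arg, so self still has to be supplied positionally; calling it with no positional argument fails with a TypeError about the missing self, the same kind of failure as in the traceback above.

```python
import functools

def my_len(self, arg):
    return arg + 1

# The partial pre-fills only the keyword argument `arg`.
my_len_bound = functools.partial(my_len, arg=1)

print(my_len_bound.args)      # ()
print(my_len_bound.keywords)  # {'arg': 1}

# `self` must still be passed positionally; any placeholder works here.
print(my_len_bound(None))     # 2

# With no positional argument, `self` is missing and a TypeError is raised.
try:
    my_len_bound()
except TypeError as exc:
    print('TypeError:', exc)
```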

Plain monkey patching (if it were possible here) also only works when __len__ is called manually with the instance as its first argument, which is of course undesirable:

ToOverride.__len__ = my_len_bound
inst = ToOverride()
print(inst.__len__(inst))  # 2

Solution

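  • Old question, but anyway: instead of functools.partial, use functools.partialmethod (available since Python 3.4). Here, the partial object stored on the class is not bound to the instance on lookup the way a regular function is, so len(inst) ends up calling my_len without self. functools.partialmethod is designed for use as a method: it binds the instance as the first argument and then applies the pre-filled arguments.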

    As in:

    from unittest import mock
    import functools
    
    class ToOverride(object):
        def __len__(self):
            raise NotImplementedError()
    
    def my_len(self, arg):
        return arg+1
    
    my_len_bound = functools.partialmethod(my_len, arg=1)
    
    with mock.patch.object(ToOverride, '__len__', my_len_bound):
        inst = ToOverride()
        assert len(inst) == 2
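A plain wrapper function is another option, because regular functions are bound to the instance when looked up on the class, so self is passed automatically. This is a minimal sketch; my_len_with_default is a hypothetical name, and the final check shows that mock.patch.object puts the original __len__ back when the with block exits.

```python
from unittest import mock

class ToOverride(object):
    def __len__(self):
        raise NotImplementedError()

def my_len(self, arg):
    return arg + 1

# Hypothetical wrapper: a plain function is bound to the instance on lookup,
# so `self` is passed automatically and `arg` is fixed to 1 here.
def my_len_with_default(self):
    return my_len(self, arg=1)

with mock.patch.object(ToOverride, '__len__', my_len_with_default):
    print(len(ToOverride()))  # 2

# After the with block, the original __len__ is restored.
try:
    len(ToOverride())
except NotImplementedError:
    print('original __len__ restored')
```

A lambda such as lambda self: my_len(self, arg=1) would do the same job; the named function just reads more clearly in tracebacks.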