
Python: how can I override a complicated function during unittest?


I am using Python's unittest module to test a script I am writing.

The script contains a loop like this:

# my_script.py

def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c

I have no problems using unittest to test complicated_function. But I would like to test my_loopy_function by overriding complicated_function.

I tried modifying my script so that my_loopy_function takes complicated_function as an optional parameter so that I can pass in a simple version from the test:

# my_modified_script.py

def my_loopy_function(action_function=None):
    if action_function is not None:
        complicated_function = action_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value

def complicated_function(x):
    a = do()
    b = something()
    c = complicated()
    return a + b + c

# test_my_script.py

import unittest

from my_modified_script import my_loopy_function

class TestMyScript(unittest.TestCase):
    def test_loopy_function(self):
        # The stand-in must accept x, because the loop calls it with one argument.
        def simple_function(x):
            return 1
        self.assertEqual(10, my_loopy_function(action_function=simple_function))

It has not worked as I had hoped. Are there any suggestions on how I should be doing this?


Solution

  • In the end I used Python's mock, which allows me to override complicated_function without having to adjust the original code in any way.

    Here is the original script. Note that complicated_function is not passed in to my_loopy_function as an 'action_function' parameter (which is what I tried in my earlier attempt):

    # my_script.py
    
    def my_loopy_function():
        aggregate_value = 0
        for x in range(10):
            aggregate_value = aggregate_value + complicated_function(x)
        return aggregate_value
    
    def complicated_function(x):
        a = do()
        b = something()
        c = complicated()
        return a + b + c
    

    and here is the script I am using to test it:

    # test_my_script.py
    
    import unittest
    import mock  # PyPI package on Python 2.7; on Python 3.3+ use "from unittest import mock"
    from my_script import my_loopy_function
    
    class TestMyModule(unittest.TestCase):
        @mock.patch('my_script.complicated_function')
        def test_1(self, mocked):
            mocked.return_value = 1
            self.assertEqual(10, my_loopy_function())
    

    This works just as I had wanted:

    1. I am able to substitute functions with a simpler version of themselves that I can more easily test,
    2. I do not need to alter my original code in any way (my earlier attempt effectively passed in a function pointer); the mock module gives me access to the innards after the code has been written.
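A mock also records how it was called, so the same test can check that the loop invoked complicated_function once per value of x. The following is a minimal single-file sketch, assuming Python 3's unittest.mock. Because there is no separate my_script module here, it swaps the name with mock.patch.dict on the function's globals; the helper run_with_fake and the placeholder body of complicated_function are illustrative and not part of the original script.

```python
from unittest import mock  # Python 3.3+; on Python 2.7 use "import mock" from PyPI


def complicated_function(x):
    # Placeholder for the real function (assumption): it must never run in the test.
    raise RuntimeError("real complicated_function was called")


def my_loopy_function():
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + complicated_function(x)
    return aggregate_value


def run_with_fake():
    # Hypothetical helper: run the loop with complicated_function replaced by a Mock.
    fake = mock.Mock(return_value=1)
    # my_loopy_function looks the name up in its module globals, so replacing the
    # entry there is enough. With a separate my_script.py file,
    # mock.patch('my_script.complicated_function') does the equivalent job.
    with mock.patch.dict(my_loopy_function.__globals__,
                         {"complicated_function": fake}):
        total = my_loopy_function()
    return total, fake


total, fake = run_with_fake()
print(total)            # 10
print(fake.call_count)  # 10
# The mock was called with x = 0, 1, ..., 9, in that order.
assert fake.call_args_list == [mock.call(x) for x in range(10)]
```

The original entry is restored when the with block exits, so other tests still see the real complicated_function.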

    Thanks to austin for his suggestion to use mock. BTW I am using Python 2.7 and therefore used the pip-installable mock from PyPI.
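For comparison, the parameter-passing attempt from the question can also be made to work. In that version, assigning to complicated_function inside my_loopy_function makes the name local to the function, so calling it without action_function raises UnboundLocalError. Whether that is the failure the question ran into is an assumption, since the error is not quoted. A minimal sketch of the fix follows; the body of complicated_function is a placeholder, not the original.

```python
def complicated_function(x):
    # Placeholder body (assumption); the original calls do(), something(), complicated().
    return x * 100


def my_loopy_function(action_function=None):
    # Fall back to the real function without rebinding its name locally.
    if action_function is None:
        action_function = complicated_function
    aggregate_value = 0
    for x in range(10):
        aggregate_value = aggregate_value + action_function(x)
    return aggregate_value


def simple_function(x):
    # Must accept x, since the loop passes it in.
    return 1


print(my_loopy_function(action_function=simple_function))  # 10
print(my_loopy_function())                                 # 4500
```

This keeps the test free of any mocking library, at the cost of an extra parameter in the production code, which is the trade-off the mock-based solution avoids.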