How to mock a base class entirely when testing the derived class in Python


I want to test a method of a derived class in Python. The base class's __init__ method is not well suited to testing its derived classes, for reasons too complicated to explain. I would like to mock the entire base class, so that the derived class inherits from the mock instead of the annoying base class.

Here is a stripped-down example of the classes:

class ComplexAnnoyingBaseClass:
    def __init__(self):
        print("do not execute this class in the unit test")
    def get_value_x(self):
        return 4
    def get_value_y(self):
        return 5

def calculate_value(x, y):
    return x * y

class MyDerivedClass(ComplexAnnoyingBaseClass):
    def get_value(self, min_x, min_y):
        value = 0
        x = self.get_value_x()
        y = self.get_value_y()
        if x > min_x and y > min_y:
            value = calculate_value(x, y)
        return value

And a simple example test:

import pytest
from unittest.mock import patch, Mock

from server.example_code import (
    ComplexAnnoyingBaseClass,
    MyDerivedClass
)

def test_get_value():
    # Patch the base class with a mock object
    with patch('server.example_code.ComplexAnnoyingBaseClass',
               spec=MyDerivedClass) as base_class_mock:
        # Create a mock instance of the derived class
        my_instance = MyDerivedClass()

        # Mock the get_value_x() and get_value_y() methods
        base_class_mock.get_value_x.return_value = 2
        base_class_mock.get_value_y.return_value = 3

        # Test your code here
        result = my_instance.get_value(1, 1)
        assert result == 6

When the test ran, it printed the "do not execute" message. Also, the call to my_instance.get_value(1,1) returned 20.

This variation on the test also failed in the exact same ways:

def test_get_value():
    # Create a mock instance of ComplexAnnoyingBaseClass
    base_class_mock = Mock()
    base_class_mock.get_value_x.return_value = 2
    base_class_mock.get_value_y.return_value = 3

    # Patch the base class with the mock object
    with patch('server.example_code.ComplexAnnoyingBaseClass',
            return_value=base_class_mock):
        # Create an instance of MyDerivedClass
        my_instance = MyDerivedClass()

        # Test your code here
        result = my_instance.get_value(1, 1)
        assert result == 6

How do I mock the base class entirely so that its __init__ function is not called during the test?


Solution

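  • By the time you call patch, MyDerivedClass already holds its own internal reference to the base class (in MyDerivedClass.__bases__). Patching the name server.example_code.ComplexAnnoyingBaseClass rebinds that module attribute only, so the derived class still inherits from the real base class.

    You don't need to patch that reference, though; you only need to patch the attributes of ComplexAnnoyingBaseClass itself.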

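    # Parenthesized context managers are officially supported from Python 3.10.
    # Each patch replaces one attribute on the real base class, so
    # MyDerivedClass inherits the replacements for the duration of the block.
    with (patch.object(ComplexAnnoyingBaseClass, '__init__', lambda self: None),
          patch.object(ComplexAnnoyingBaseClass, 'get_value_x', lambda self: 2),
          patch.object(ComplexAnnoyingBaseClass, 'get_value_y', lambda self: 3)):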
    
        my_instance = MyDerivedClass()
    
        # Test your code here
        result = my_instance.get_value(1, 1)
        assert result == 6
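
The reason the original tests failed can be checked in isolation. The sketch below is a minimal illustration, not the asker's real module: the SimpleNamespace named example_code is a hypothetical stand-in for the server.example_code module, and the classes are cut down to a single method. It shows that patching the name rebinds the namespace attribute while MyDerivedClass.__bases__ keeps pointing at the real class.

```python
import types
from unittest.mock import patch


class ComplexAnnoyingBaseClass:
    def get_value_x(self):
        return 4


class MyDerivedClass(ComplexAnnoyingBaseClass):
    pass


# Hypothetical stand-in for the server.example_code module namespace.
example_code = types.SimpleNamespace(
    ComplexAnnoyingBaseClass=ComplexAnnoyingBaseClass,
    MyDerivedClass=MyDerivedClass,
)

with patch.object(example_code, "ComplexAnnoyingBaseClass") as base_class_mock:
    # The name in the namespace now points at the mock...
    name_rebound = example_code.ComplexAnnoyingBaseClass is base_class_mock
    # ...but the derived class still inherits from the real base class.
    bases_unchanged = MyDerivedClass.__bases__ == (ComplexAnnoyingBaseClass,)
    # So method lookup on an instance still reaches the real implementation.
    value_during_patch = MyDerivedClass().get_value_x()
```

All three checks hold inside the with block, which matches the behaviour described in the question: the real methods (and the real __init__) are still the ones being called.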
    

    Replacing ComplexAnnoyingBaseClass in MyDerivedClass.__bases__ is possible, but tricky. __bases__ must be a tuple of classes, not a mock or a tuple of mocks. In addition, patch.object(MyDerivedClass, '__bases__', (NewBase,)) apparently tries to delete __bases__ when the patch ends, before restoring the original value, and Python doesn't allow __bases__ to be deleted. You can be a bit blunter and assign it directly:

    MyDerivedClass.__bases__ = (NewBase,)
    # test code here
    MyDerivedClass.__bases__ = (ComplexAnnoyingBaseClass,)
    

    but I recommend against such manual patching.
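
If swapping the base really is needed, one way to limit the risk of the manual approach is to wrap the assignment in try/finally so the original bases are restored even when the test body raises. The following is only a sketch under assumptions: swapped_bases and FakeBase are hypothetical names introduced here, init_calls exists purely to observe whether the real __init__ ran, and get_value is condensed from the question's version. The replacement base must be a real class whose layout is compatible with the original one, which holds for plain Python classes like these.

```python
from contextlib import contextmanager

init_calls = []  # records calls to the real __init__ (illustration only)


class ComplexAnnoyingBaseClass:
    def __init__(self):
        init_calls.append("real __init__")

    def get_value_x(self):
        return 4

    def get_value_y(self):
        return 5


class MyDerivedClass(ComplexAnnoyingBaseClass):
    def get_value(self, min_x, min_y):
        x = self.get_value_x()
        y = self.get_value_y()
        return x * y if x > min_x and y > min_y else 0


class FakeBase:
    """Hypothetical replacement base: a real class, not a Mock."""

    def get_value_x(self):
        return 2

    def get_value_y(self):
        return 3


@contextmanager
def swapped_bases(cls, *new_bases):
    """Temporarily assign cls.__bases__ and always restore the original."""
    original_bases = cls.__bases__
    cls.__bases__ = new_bases
    try:
        yield cls
    finally:
        cls.__bases__ = original_bases


with swapped_bases(MyDerivedClass, FakeBase):
    result = MyDerivedClass().get_value(1, 1)

restored = MyDerivedClass.__bases__ == (ComplexAnnoyingBaseClass,)
```

Inside the block, MyDerivedClass resolves __init__, get_value_x and get_value_y through FakeBase, so the real __init__ never runs. Patching the base class attributes with patch.object remains the simpler option.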