python, inheritance, unit-testing

Python unittest TestCase with inheritance


Currently I have many similar unittest TestCases. Each TestCase contains both data (input values + expected output values) and logic (call the SUT, i.e. the system under test, and compare the actual output with the expected output).

I would like to separate the data from the logic. Thus I want a base class that only contains the logic and a derived class that contains only the data. I came up with this so far:

import unittest

class MyClass():
    def __init__(self, input):
        self.input = input
    def get_result(self):
        return self.input * 2

class TestBase(unittest.TestCase):
    def check(self, input, expected_output):
        obj = self.class_under_test(input)
        actual_output = obj.get_result()
        self.assertEqual(actual_output, expected_output)

    def test_get_result(self):
        for value in self.values:
            self.check(value[0], value[1])

class TestMyClass(TestBase):
    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName)
        self.class_under_test = MyClass
        self.values = [(1, 2), (3, 6)]

unittest.main(exit = False)

But this fails with the following error:

AttributeError: 'TestBase' object has no attribute 'values'

Two questions:

  • Is my 'design' any good?
  • What's still needed to get it working?

Solution

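  • The design is (more or less) fine. The one "hiccup" is that unittest collects every TestCase subclass in the module, including TestBase itself, and runs each method whose name starts with "test" on it. TestBase defines test_get_result but never sets values, so running that test on a plain TestBase instance raises the AttributeError. You have a few options at this point.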

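    One approach is to specify the class under test and the values as class attributes, and to give the shared loop a name that does not start with "test" (check_all below), so that TestBase itself has nothing for unittest to run. Each subclass then adds a short test method that calls it. Because class attributes are shared by all instances, you'll want the values to be immutable (tuples rather than lists) where possible.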

    import unittest
    
    class TestBase(unittest.TestCase):
    
        def check(self, input, expected_output):
            obj = self.class_under_test(input)
            actual_output = obj.get_result()
            self.assertEqual(actual_output, expected_output)
    
        # Deliberately not named test_*, so unittest does not run it on TestBase itself
        def check_all(self):
            for value in self.values:
                self.check(value[0], value[1])
    
    # MyClass1 and MyClass2 stand for your own classes under test
    class TestMyClass1(TestBase):
        values = ((1, 2), (3, 4))
        class_under_test = MyClass1
    
        def test_it(self):
            self.check_all()
    
    class TestMyClass2(TestBase):
        values = (('a', 'b'), ('d', 'e'))
        class_under_test = MyClass2
    
        def test_it(self):
            self.check_all()
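To check that this layout avoids the original error, here is a self-contained sketch, assuming Python 3. The bodies of `MyClass1` and `MyClass2` are hypothetical stand-ins chosen only so the sample values pass (add one; next character). As before, an explicit `TestLoader`/`TextTestRunner` replaces `unittest.main()` so the result can be inspected.

```python
import io
import unittest


# Hypothetical classes under test, invented to match the sample values.
class MyClass1:
    def __init__(self, input):
        self.input = input

    def get_result(self):
        return self.input + 1


class MyClass2:
    def __init__(self, input):
        self.input = input

    def get_result(self):
        # Next character, e.g. 'a' -> 'b'.
        return chr(ord(self.input) + 1)


class TestBase(unittest.TestCase):
    # Logic only: no method name starts with "test", so the loader
    # finds nothing to run on TestBase itself.
    def check(self, input, expected_output):
        obj = self.class_under_test(input)
        self.assertEqual(obj.get_result(), expected_output)

    def check_all(self):
        for input, expected_output in self.values:
            self.check(input, expected_output)


class TestMyClass1(TestBase):
    # Data only, stored as immutable class attributes.
    values = ((1, 2), (3, 4))
    class_under_test = MyClass1

    def test_it(self):
        self.check_all()


class TestMyClass2(TestBase):
    values = (('a', 'b'), ('d', 'e'))
    class_under_test = MyClass2

    def test_it(self):
        self.check_all()


loader = unittest.TestLoader()
suite = unittest.TestSuite(
    loader.loadTestsFromTestCase(cls)
    for cls in (TestBase, TestMyClass1, TestMyClass2)
)
result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
print(result.testsRun, result.wasSuccessful())  # 2 True
```

Even though `TestBase` is loaded explicitly, it contributes zero tests, so only the two subclasses run.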