
Does pytest support the use of function factories in test files?


Example test.py file:

import torch

def one():
    return torch.tensor(0.0132005215)

def two():
    return torch.tensor(4.4345855713e-05)

def three():
    return torch.tensor(7.1525573730e-07)


def test_method(method, expected_value):
    value = method()
    assert torch.isclose(value, expected_value)

def test_one():
    test_method(one, torch.tensor(0.0132005215))

def test_two():
    test_method(two, torch.tensor(4.4345855713e-05))

def test_three():
    test_method(three, torch.tensor(7.1525573730e-07))
    # test_method(three, torch.tensor(1.0))

if __name__ == '__main__':
    test_one()
    test_two()
    test_three()

Basically, I have a few functions I want to test (here called one, two and three), all with the same signature but different internals. So instead of repeating the same checking code in test_one(), test_two(), etc., I wrote a "function factory" (is this the right term?), test_method, which takes the function and the expected result as inputs and asserts that the function returns that value.

As you can see, right now the tests are performed manually: I run the script test.py, look at the screen, and if no AssertionError gets printed, I'm happy. I'd like to improve on this by using pytest, since I've been told it's one of the simplest and most widely used Python testing frameworks. The problem is that, from reading the pytest documentation, I got the impression that pytest will try to run every function whose name starts with test_, and testing test_method itself doesn't make any sense. Can you help me refactor this test script so that I can run it with pytest?
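pytest's default rules collect every function whose name starts with test, so the original test_method would be collected as well, and pytest would report an error because it tries to supply method and expected_value as fixtures. The smallest change, assuming you want to keep one test function per case, is to rename the helper so pytest ignores it. In the minimal sketch below, check_method is a hypothetical name, and one, two and three return plain floats checked with math.isclose instead of torch tensors checked with torch.isclose; that substitution is an assumption made only so the sketch runs without PyTorch.

```python
import math


# Plain-float stand-ins for the question's torch-based functions
# (assumption: lets the sketch run without PyTorch installed).
def one():
    return 0.0132005215


def two():
    return 4.4345855713e-05


def three():
    return 7.1525573730e-07


# Helper: its name does not start with "test", so pytest's default
# collection rules do not treat it as a test function.
def check_method(method, expected_value):
    value = method()
    assert math.isclose(value, expected_value), (
        f"{method.__name__}() returned {value}, expected {expected_value}"
    )


# These are collected by pytest and simply delegate to the helper.
def test_one():
    check_method(one, 0.0132005215)


def test_two():
    check_method(two, 4.4345855713e-05)


def test_three():
    check_method(three, 7.1525573730e-07)
```

This keeps the question's structure intact, but every new case still needs its own small wrapper function.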


Solution

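  • In pytest, you can use test parametrization (the pytest.mark.parametrize decorator) to achieve this. You provide the varying parameters to a single test function, and pytest generates one test per parameter set, so the separate test_one, test_two and test_three wrappers are no longer needed: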

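    import pytest
    import torch

    # one, two and three are assumed to be defined in this same file,
    # exactly as in the question.

    @pytest.mark.parametrize("method, expected_value",
                             [(one, 0.0132005215),
                              (two, 4.4345855713e-05),
                              (three, 7.1525573730e-07)])
    def test_method(method, expected_value):
        value = method()
        # torch.isclose expects tensors for both arguments, so wrap the float.
        assert torch.isclose(value, torch.tensor(expected_value))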
    

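    Pass the file name explicitly, because pytest's default discovery only picks up files named test_*.py or *_test.py, which test.py does not match. If you run python -m pytest -rA test.py (-rA adds a summary line for every test, including passed ones; see the documentation for the other output options), you get three separate tests, with output similar to: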

    ======================================================= PASSES ========================================================
    =============================================== short test summary info ===============================================
    PASSED test.py::test_method[one-0.0132005215]
    PASSED test.py::test_method[two-4.4345855713e-05]
    PASSED test.py::test_method[three-7.152557373e-07]
    ================================================== 3 passed in 0.07s ==================================================
    

    If you don't like the automatically generated test IDs, you can set your own with the ids argument:

    @pytest.mark.parametrize("method, expected_value",
                             [(one, 0.0132005215),
                              (two, 4.4345855713e-05),
                              (three, 7.1525573730e-07)],
                             ids=["one", "two", "three"])
    ...
    

    This gives you instead:

    =============================================== short test summary info ===============================================
    PASSED test.py::test_method[one]
    PASSED test.py::test_method[two]
    PASSED test.py::test_method[three]
    ================================================== 3 passed in 0.06s ==================================================
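    Instead of a separate ids list, each case can carry its own ID via pytest.param(..., id=...), which keeps the ID next to its values so adding or reordering cases cannot misalign the two lists. A minimal sketch, assuming plain-float stand-ins for one, two and three and pytest.approx for the comparison so it runs without PyTorch; with the real torch functions you would keep the torch.isclose check from the answer above.

```python
import pytest


# Plain-float stand-ins (assumption: lets the sketch run without PyTorch).
def one():
    return 0.0132005215


def two():
    return 4.4345855713e-05


def three():
    return 7.1525573730e-07


@pytest.mark.parametrize(
    "method, expected_value",
    [
        # Each pytest.param bundles one case's values with its test ID.
        pytest.param(one, 0.0132005215, id="one"),
        pytest.param(two, 4.4345855713e-05, id="two"),
        pytest.param(three, 7.1525573730e-07, id="three"),
    ],
)
def test_method(method, expected_value):
    # pytest.approx compares floats with a small relative tolerance.
    assert method() == pytest.approx(expected_value)
```

    Running python -m pytest -rA test.py on such a file should list the cases as test_method[one], test_method[two] and test_method[three], the same as with the ids argument.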