Tags: python, unit-testing, mocking, pytest, monkeypatching

Pytest: mock/patch sys.stdin in program using threading with python


I've acquired some code that I need to test before refactoring it. It uses deep recursion, so it raises the recursion limit and the thread stack size and then runs main in a fresh thread:

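import sys
import threading

sys.setrecursionlimit(10**6)   # allow very deep Python recursion
threading.stack_size(2**27)    # 128 MiB stack for threads created after this call
...
threading.Thread(target=main).start()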
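A minimal sketch of that startup pattern, with a hypothetical recursive depth() function standing in for the real code. The additions are mine, not the original's: join() (Thread.start() returns immediately, so a caller such as a test has to wait for the worker before looking at results), and restoring both limits afterwards so they don't leak into other tests.

```python
import sys
import threading


def depth(n: int) -> int:
    """Hypothetical stand-in for the deeply recursive code: recurses n times."""
    return 0 if n == 0 else 1 + depth(n - 1)


def run_with_big_stack(target, *args) -> None:
    """Run target(*args) in a fresh thread with raised recursion/stack limits."""
    old_limit = sys.getrecursionlimit()
    old_stack = threading.stack_size()      # 0 means "platform default"
    sys.setrecursionlimit(10**6)
    threading.stack_size(2**27)             # 128 MiB, applies to threads created from now on
    try:
        worker = threading.Thread(target=target, args=args)
        worker.start()
        worker.join()                       # block until the worker has finished
    finally:
        # Put the process-wide settings back for whatever runs next.
        threading.stack_size(old_stack)
        sys.setrecursionlimit(old_limit)


results = []
run_with_big_stack(lambda: results.append(depth(50_000)))
print(results)  # [50000]
```

Note that threading.stack_size() only affects threads created after the call, which is why the original code sets it before constructing the Thread.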

The code relies heavily on sys.stdin and sys.stdout, e.g.:

import sys

class SpamClass:
    def read(self):
        # input comes straight from sys.stdin
        self.n = int(sys.stdin.readline())
        ...
        for i in range(self.n):
            [a, b, c] = map(int, sys.stdin.readline().split())
    ...
    def write(self):
        # output goes to sys.stdout via print()
        print(" ".join(str(x) for x in spam()))
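Because read() and write() look up sys.stdin and sys.stdout at call time, a class shaped like this can be exercised directly, with no threading involved, by swapping in in-memory streams. A minimal standard-library sketch; TripleSum is a hypothetical stand-in for SpamClass (the real spam() logic isn't shown), and it simply prints the sum of each row.

```python
import io
import sys
from contextlib import redirect_stdout
from unittest.mock import patch


class TripleSum:
    """Hypothetical stand-in for SpamClass: reads n, then n rows of three ints."""

    def read(self):
        self.n = int(sys.stdin.readline())
        self.rows = []
        for _ in range(self.n):
            a, b, c = map(int, sys.stdin.readline().split())
            self.rows.append((a, b, c))

    def write(self):
        print(" ".join(str(a + b + c) for a, b, c in self.rows))


def run(text: str) -> str:
    """Feed `text` on stdin and return whatever write() printed."""
    buf = io.StringIO()
    # patch() swaps sys.stdin only for the duration of the with block;
    # redirect_stdout() does the same for sys.stdout.
    with patch("sys.stdin", io.StringIO(text)), redirect_stdout(buf):
        obj = TripleSum()
        obj.read()
        obj.write()
    return buf.getvalue()


print(repr(run("2\n1 2 3\n4 5 6\n")))  # '6 15\n'
```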

To test the code, I need to pass in the contents of a series of input files and compare the results to the contents of some corresponding sample output files.
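Pairing each input file with its expected output can be done with a small helper. A minimal sketch under an assumed layout: every sample input ends in .in and sits next to an expected-output file with the same stem ending in .out. The naming scheme and the function name load_sample_pairs are hypothetical; the result is a list of (input_lines, output_lines) tuples, the shape pytest.mark.parametrize("inputs, outputs", ...) accepts.

```python
import tempfile
from pathlib import Path


def load_sample_pairs(directory):
    """Return [(input_lines, expected_lines), ...] for each *.in with a matching *.out.

    Hypothetical layout: 01.in + 01.out, 02.in + 02.out, ...
    """
    pairs = []
    for in_path in sorted(Path(directory).glob("*.in")):
        out_path = in_path.with_suffix(".out")
        if not out_path.exists():
            continue  # skip inputs that have no expected answer
        with open(in_path) as f_in, open(out_path) as f_out:
            pairs.append((f_in.readlines(), f_out.readlines()))
    return pairs


# Usage example with a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "01.in").write_text("2\n1 2 3\n4 5 6\n")
    Path(tmp, "01.out").write_text("6 15\n")
    Path(tmp, "02.in").write_text("no expected output\n")
    pairs = load_sample_pairs(tmp)
    print(pairs)  # [(['2\n', '1 2 3\n', '4 5 6\n'], ['6 15\n'])]
```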

So far, I've tried three or four different types of mocking and patching without success. My other tests are all written for pytest, so it would be a real nuisance to have to use something else.

I've tried patching module.sys.stdin with a StringIO, which doesn't seem to work: pytest's capturing replaces sys.stdin with a null input that raises an error when read, and that error is still thrown despite the patch.

I've also tried using pytest's monkeypatch fixture to replace the module.SpamClass.read method with a function defined in the test, but that crashes the pytest process, I think because the thread exits before the test does (or …?).

'pytest test_spam.py' terminated by signal SIGBUS (Misaligned address error)

Any suggestions for how to do this right? Many thanks.
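One likely factor in both failures (an assumption; it is never confirmed in this thread): Thread.start() returns as soon as the thread is launched, so if the test doesn't call join(), the with patch(...) block can exit, restoring the original sys.stdin, before the worker thread gets around to reading. A minimal standard-library sketch of an ordering that avoids this: start and join the thread inside the patch, and collect stdout the same way. fake_main is a hypothetical stand-in for the module's main.

```python
import io
import sys
import threading
from contextlib import redirect_stdout
from unittest.mock import patch


def fake_main():
    """Hypothetical stand-in for module.main: prints each input number doubled."""
    n = int(sys.stdin.readline())
    values = [int(sys.stdin.readline()) * 2 for _ in range(n)]
    print(" ".join(map(str, values)))


def run_main_in_thread(main, stdin_text: str) -> str:
    out = io.StringIO()
    with patch("sys.stdin", io.StringIO(stdin_text)), redirect_stdout(out):
        worker = threading.Thread(target=main)   # pass the function, don't call it
        worker.start()
        worker.join()                            # keep both swaps alive until main returns
    return out.getvalue()


print(repr(run_main_in_thread(fake_main, "3\n1\n2\n3\n")))  # '2 4 6\n'
```

Both patch() and redirect_stdout() replace the attribute on the sys module itself, so the worker thread sees the replacements too. In a pytest test, capsys can take the place of redirect_stdout, since it also captures writes to sys.stdout.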


Solution

  • Well, I still don't know what the problem was or whether I'm doing this right, but it works for now. I'm not confident the threading aspect is working correctly, but the rest seems fine.

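    from io import StringIO
    from unittest.mock import patch

    import pytest

    import helpers
    import module

    @pytest.mark.parametrize("inputs, outputs", helpers.get_sample_tests('spampath'))
    def test_tree_orders(capsys, inputs, outputs):
        """Feed one sample input through stdin and compare stdout with the expected output."""
        with patch('module.sys.stdin', StringIO("".join(inputs))):
            # Pass main itself (not main()) so it actually runs in the new thread,
            # and join() so it finishes while stdin is still patched.
            worker = module.threading.Thread(target=module.main)
            worker.start()
            worker.join()

        captured = capsys.readouterr()

        assert "".join(outputs) == captured.out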

    For anyone else who's interested: it helps to write your debugging prints as print(spam, file=sys.stderr). They then show up in the test as captured.err, kept separate from captured.out, which is what the assertion checks.
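The same stdout/stderr split can be seen without pytest. A minimal sketch using contextlib.redirect_stdout and redirect_stderr, where the two buffers play the roles of captured.out and captured.err; solve is a hypothetical function.

```python
import io
import sys
from contextlib import redirect_stderr, redirect_stdout


def solve(values):
    """Hypothetical solver: answer goes to stdout, debug info to stderr."""
    print(f"debug: got {len(values)} values", file=sys.stderr)
    print(sum(values))


out, err = io.StringIO(), io.StringIO()
with redirect_stdout(out), redirect_stderr(err):
    solve([1, 2, 3])

print(repr(out.getvalue()))  # '6\n'
print(repr(err.getvalue()))  # 'debug: got 3 values\n'
```

Since the comparison only looks at stdout, debug lines written to stderr never break it.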