
How to write a good unit test for testing a function that involves taking input from the user?


Disclaimer: this is my first question here, so apologies if it's framed poorly. Please ask for clarification if needed.

This is the function that I want to test:

def driver():
    n = int(input("Enter number of rows: "))
    m = int(input("Enter number of columns: "))
    if n == 0 or m == 0:
        raise ValueError("Invalid grid size.")
    grid = []
    for row in range(n):
        row_i = input(f"Enter the space-separated cells of row {row + 1}: ")
        row_i = list(map(int, row_i.split()))
        if len(row_i) != m:
            raise IndexError("Invalid input for the given number of columns.")
        if any(cell not in (0, 1) for cell in row_i):
            raise ValueError("Invalid value of cell, a cell can only have 0 or 1 as a value.")
        grid.append(row_i)
    print("Initial grid: ")
    print_grid(grid)

The tests I have written are as follows:

import unittest
from unittest import mock

import game_of_life


class TestGameOfLife(unittest.TestCase):
    def setUp(self):
        self.driver = game_of_life.driver

    @mock.patch('game_of_life.input', create=True)
    def test_driver_invalid_num_rows(self, mocked_input):
        mocked_input.side_effect = ["0", "5"]
        self.assertRaisesRegex(ValueError, "Invalid grid size.", self.driver)
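For comparison, here is a self-contained sketch of the same mock.patch approach that runs on its own with unittest. It inlines a copy of driver() and a hypothetical stand-in for print_grid (the real helper is not shown in the question). Instead of patching game_of_life.input with create=True (which works because a module-level name shadows the built-in), it patches builtins.input. Each item in side_effect is returned by one successive call to input(), and sys.stdout is patched with a StringIO so the printed grid can be checked.

```python
import io
import unittest
from unittest import mock


def print_grid(grid):
    # Hypothetical stand-in for the question's print_grid helper.
    for row in grid:
        print(" ".join(str(cell) for cell in row))


def driver():
    # Copy of the question's driver(), inlined so this block runs alone.
    n = int(input("Enter number of rows: "))
    m = int(input("Enter number of columns: "))
    if n == 0 or m == 0:
        raise ValueError("Invalid grid size.")
    grid = []
    for row in range(n):
        line = input(f"Enter the space-separated cells of row {row + 1}: ")
        cells = list(map(int, line.split()))
        if len(cells) != m:
            raise IndexError("Invalid input for the given number of columns.")
        if any(cell not in (0, 1) for cell in cells):
            raise ValueError("Invalid value of cell, a cell can only have 0 or 1 as a value.")
        grid.append(cells)
    print("Initial grid: ")
    print_grid(grid)


class TestDriver(unittest.TestCase):
    @mock.patch("builtins.input", side_effect=["0", "5"])
    def test_zero_rows(self, mocked_input):
        # Escape the dot so the regex matches a literal full stop.
        with self.assertRaisesRegex(ValueError, r"Invalid grid size\."):
            driver()

    @mock.patch("builtins.input", side_effect=["1", "3", "0 1"])
    def test_wrong_column_count(self, mocked_input):
        # Two cells supplied where three columns were promised.
        with self.assertRaises(IndexError):
            driver()

    @mock.patch("builtins.input", side_effect=["1", "2", "0 2"])
    def test_invalid_cell(self, mocked_input):
        with self.assertRaisesRegex(ValueError, "Invalid value of cell"):
            driver()

    # The bottom decorator is applied first, so its mock is the first argument.
    @mock.patch("sys.stdout", new_callable=io.StringIO)
    @mock.patch("builtins.input", side_effect=["2", "2", "0 1", "1 0"])
    def test_valid_grid(self, mocked_input, mocked_stdout):
        driver()
        self.assertIn("0 1\n1 0", mocked_stdout.getvalue())
        # One input() call each for rows, columns and the two grid rows.
        self.assertEqual(mocked_input.call_count, 4)
```

Saved under a name such as test_game_of_life.py (hypothetical), this runs with python -m unittest; pytest also collects unittest.TestCase subclasses.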

The issue is that neither this test nor the lines of code it is trying to test are included in the coverage report. So I reckon it's not supposed to be done this way. Can anyone help me with how I should be testing it instead?

I took inspiration from this post to write my unit tests, but given that neither the code nor the tests are included in coverage, this is probably not the appropriate way to do this in my case.


Solution

  • I left some thoughts in a comment, but to answer your specific question: it is possible to replace stdin and stdout with in-memory streams (such as io.StringIO), so that you can feed a function test input and read the output it produces.

    Assume we have a simplified version of your driver code in driver.py:

    def driver():
        n = int(input("Enter number of rows: "))
        m = int(input("Enter number of columns: "))
        if n == 0 or m == 0:
            raise ValueError("Invalid grid size.")
    
        print(f"grid: {n}x{m}")
    

    We can test it like this:

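    import io
    import sys

    import pytest

    import driver


    def test_driver(monkeypatch):
        stdin = io.StringIO("10\n10\n")
        stdout = io.StringIO()
        # monkeypatch restores the real sys.stdin/sys.stdout after the test,
        # so the fake streams do not leak into other tests.
        monkeypatch.setattr(sys, "stdin", stdin)
        monkeypatch.setattr(sys, "stdout", stdout)
        driver.driver()
        assert "grid: 10x10" in stdout.getvalue()


    def test_driver_invalid_input(monkeypatch):
        # int("x") raises ValueError before the grid-size check is reached.
        monkeypatch.setattr(sys, "stdin", io.StringIO("x\ny\n"))
        with pytest.raises(ValueError):
            driver.driver()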
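Since the question uses unittest rather than pytest, the same stdin/stdout replacement can be sketched with only the standard library. This assumes the simplified driver() above, repeated inline so the block runs on its own. mock.patch("sys.stdin", ...) swaps in an in-memory stream and restores the real one on exit, and contextlib.redirect_stdout captures what is printed, including the prompts that input() writes to sys.stdout when it is not reading from a terminal.

```python
import contextlib
import io
import unittest
from unittest import mock


def driver():
    # The simplified driver() from the answer, repeated so this block runs alone.
    n = int(input("Enter number of rows: "))
    m = int(input("Enter number of columns: "))
    if n == 0 or m == 0:
        raise ValueError("Invalid grid size.")
    print(f"grid: {n}x{m}")


class TestDriverStdin(unittest.TestCase):
    def run_driver(self, text):
        # Feed `text` to input() and return everything printed to stdout.
        # Both streams are restored even if driver() raises.
        out = io.StringIO()
        with mock.patch("sys.stdin", io.StringIO(text)), contextlib.redirect_stdout(out):
            driver()
        return out.getvalue()

    def test_valid_size(self):
        self.assertIn("grid: 10x10", self.run_driver("10\n10\n"))

    def test_zero_size(self):
        # Exercises the explicit size check, not int() parsing.
        with self.assertRaisesRegex(ValueError, r"Invalid grid size\."):
            self.run_driver("0\n5\n")

    def test_non_numeric(self):
        # int("x") raises ValueError before the size check is reached.
        with self.assertRaises(ValueError):
            self.run_driver("x\ny\n")
```

Note that only test_zero_size reaches the "Invalid grid size." branch; the "x\ny\n" input fails inside int() first, so it only checks that non-numeric input is rejected.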