Tags: python, unit-testing, python-unittest

Python testing: how to handle a 'continue' statement in a loop


I'm learning how to use Python. I have a function with a conditional inside a loop: if an invalid input is provided, it should keep prompting until a valid input is provided.

Unfortunately, this "restarting" behavior causes an infinite loop in my tests: the mocked input() returns the same invalid value on every call, so the loop never exits. How can I pause, break, or limit the loop to a single attempt so I can test the returned value?

function:

def confirm_user_choice(choice: str):
    while True:
        user_response = input(f"\nYou chose '{choice}', is this correct? y/n ")
        if user_response == "y":
            return True
        elif user_response == "n":
            return False
        else:
            print("\nSelect either 'y' (yes) or 'n' (no)")

test:

import unittest
from unittest import mock
from src.utils.utils import addValues, confirm_user_choice


class TestConfirmUserChoice(unittest.TestCase):
    def test_yes(self):
        with mock.patch("builtins.input", return_value="y"):
            result = confirm_user_choice("y")
        self.assertEqual(result, True)

    def test_no(self):
        with mock.patch("builtins.input", return_value="n"):
            result = confirm_user_choice("n")
        self.assertEqual(result, False)

    def test_invalid_input(self):
        with mock.patch("builtins.input", return_value="apple"):  # triggers the function's else branch
            result = confirm_user_choice("apple")
        self.assertEqual(result, False)

Solution

  • You have a partial function: on valid input it returns a Boolean, but on invalid input it may never return at all, and you can't write a test that proves an infinite loop is indeed infinite.

    To make it more testable, let the function take an optional iterable of responses. When none is given, it falls back to an endless stream of input() calls; a test can instead pass a finite list, which controls both what the function reads and how long it keeps trying.

    from typing import Iterable, Optional

    def confirm_user_choice(choice: str, responses: Optional[Iterable[str]] = None) -> bool:
        if responses is None:
            # An infinite stream of calls to input()
            responses = iter(lambda: input(f"\nYou chose '{choice}', is this correct? y/n "), None)
    
        for user_response in responses:
            if user_response == "y":
                return True
            elif user_response == "n":
                return False
            else:
                print("\nSelect either 'y' (yes) or 'n' (no)")
        else:
            # Only reachable when a finite `responses` iterable runs out;
            # the default stream of input() calls never ends.
            raise ValueError("Unexpected end of responses")
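
The default for `responses` relies on the two-argument form of the built-in `iter(callable, sentinel)`, which calls `callable` with no arguments on every step and stops only when the result equals `sentinel`. Because `input()` never returns `None`, the default stream is endless. The minimal sketch below uses a hypothetical `fake_input` stand-in instead of `input()` so it runs without a terminal:

```python
from itertools import islice

calls = []


def fake_input() -> str:
    # Hypothetical stand-in for input(): records each call, always answers "maybe".
    calls.append(1)
    return "maybe"


# fake_input never returns None, so a None sentinel gives an endless stream.
endless = iter(fake_input, None)
first_three = list(islice(endless, 3))
print(first_three)  # ['maybe', 'maybe', 'maybe']
print(len(calls))   # 3: items are produced lazily, one call per item

# With a sentinel the callable can actually return, the stream stops there.
answers = iter(["apple", "y", "stop", "n"])
until_stop = list(iter(lambda: next(answers), "stop"))
print(until_stop)   # ['apple', 'y']
```

Since the stream is lazy, the function only calls `input()` as often as it needs a new answer.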
    

    Now your test can simply pass canned lists of responses and either catch the expected ValueError or check the returned Boolean value.

    import unittest
    from src.utils.utils import addValues, confirm_user_choice
    
    
    class TestConfirmUserChoice(unittest.TestCase):
        def test_yes(self):
            result = confirm_user_choice("y", ["y"])
            self.assertTrue(result)
    
        def test_eventual_yes(self):
            result = confirm_user_choice("y", ["apple", "pear", "y"])
            self.assertTrue(result)
    
        def test_no(self):
            result = confirm_user_choice("y", ["n"])
            self.assertFalse(result)
    
        def test_no_valid_input(self):
            # The canned responses run out before a valid answer appears.
            with self.assertRaises(ValueError):
                confirm_user_choice("y", ["apple"])
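
If changing the function's signature isn't an option, the original version from the question can still be tested as-is. A minimal sketch, assuming the unmodified function: `mock.patch` accepts a `side_effect` iterable, so each call to the mocked `input()` returns the next item, and once the items are exhausted the mock raises `StopIteration`, which escapes the `while True` loop instead of spinning forever. `contextlib.redirect_stdout` captures the retry message so it can be asserted on.

```python
import io
import unittest
from contextlib import redirect_stdout
from unittest import mock


def confirm_user_choice(choice: str):
    # The original, unmodified function from the question.
    while True:
        user_response = input(f"\nYou chose '{choice}', is this correct? y/n ")
        if user_response == "y":
            return True
        elif user_response == "n":
            return False
        else:
            print("\nSelect either 'y' (yes) or 'n' (no)")


class TestOriginalConfirmUserChoice(unittest.TestCase):
    def test_invalid_then_yes(self):
        # Each call to input() returns the next item from side_effect.
        with mock.patch("builtins.input", side_effect=["apple", "y"]) as fake_input:
            with redirect_stdout(io.StringIO()) as out:
                result = confirm_user_choice("apple")
        self.assertTrue(result)
        self.assertEqual(fake_input.call_count, 2)
        self.assertIn("Select either 'y' (yes) or 'n' (no)", out.getvalue())

    def test_only_invalid_input(self):
        # After the last canned answer the mock raises StopIteration,
        # which ends the otherwise infinite loop.
        with mock.patch("builtins.input", side_effect=["apple", "pear"]):
            with redirect_stdout(io.StringIO()):
                with self.assertRaises(StopIteration):
                    confirm_user_choice("apple")
```

Run it with `python -m unittest`. The trade-off is that the "no valid input" case surfaces as a mock-specific `StopIteration` rather than a meaningful error raised by the function itself, which is why passing the responses in explicitly is the cleaner design.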