Tags: python, pytest, python-multithreading, python-unittest

How to assert that exception is thrown in a Thread based class in Python unit tests


I have code similar to the following, written in Python: a thread-based class.

from threading import Thread

class ThreadClass(Thread):

    def __init__(self, li):
        super(ThreadClass, self).__init__()
        self.li = li
        self.ps = list()

    def __validate(self, val):
        if val > 10:
            raise ValueError("Given value is greater than 10")

        return val % 10
    
    def __fun(self):
        ps = list()
        for i in self.li:
            p = self.__validate(i)
            ps.append(p)
        self.ps = ps
    
    def get_ps(self):
        return self.ps

    def run(self):
        self.__fun()

And the following unit test for a failure scenario:

from unittest import TestCase
from thread_class import ThreadClass

class TestThreadClassNegative(TestCase):

    def test_val_greater_than_valid_fail(self):
        get_data = ThreadClass(li = [2,3,5,11,6])
        with self.assertRaises(ValueError):
            get_data.start()
            get_data.join()

The above test fails with "AssertionError: ValueError not raised", but I can clearly see the exception being raised in the captured stderr:

================================== FAILURES ===================================
__________ TestThreadClassNegative.test_val_greater_than_valid_fail ___________
[gw0] win32 -- Python 3.7.10 C:\Users\<user>\AppData\Local\Continuum\anaconda3\envs\venv\python.exe

self = <tests.unit.test_thread_class.TestThreadClassNegative testMethod=test_val_greater_than_valid_fail>

    def test_val_greater_than_valid_fail(self):
        get_data = ThreadClass(li=[2, 3, 5, 11, 6])
        with self.assertRaises(ValueError):
            get_data.start()
>           get_data.join()
E           AssertionError: ValueError not raised

test_thread_class.py:9: AssertionError
---------------------------- Captured stderr call -----------------------------
Exception in thread Thread-1:
Traceback (most recent call last):
  File "C:\Users\<user>\AppData\Local\Continuum\anaconda3\envs\venv\lib\threading.py", line 926, in _bootstrap_inner
    self.run()
  File "C:\Users\<user>\projects\pytest_project\thread_class.py", line 27, in run
    self.__fun()
  File "C:\Users\<user>\projects\pytest_project\thread_class.py", line 19, in __fun
    p = self.__validate(i)
  File "C:\Users\<user>\projects\pytest_project\thread_class.py", line 12, in __validate
    raise ValueError("Given value is greater than 10")
ValueError: Given value is greater than 10

Why is that test case failing? What am I doing wrong and how can I fix it?


Solution

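  • The exception is raised in the child thread and terminates that thread; that's the end of it.

    An exception raised in a child thread does not propagate to the thread that started or joined it. Thread.join() just waits for the thread to finish and returns None, so nothing is raised inside your assertRaises block. The traceback you see is printed to stderr by the threading module itself (in Python 3.8+ this goes through threading.excepthook).

    One possible solution is a wrapper class that catches the exception in run() and re-raises it in join():

    EDIT: I had the subclassing wrong at first; this version is tested and should work.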

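    from threading import Thread

    class Propogate(Thread):
        """Run obj.run() in a new thread and re-raise its exception in join()."""

        def __init__(self, obj):
            super().__init__()
            self.obj = obj   # the object whose run() does the real work
            self.ex = None   # exception raised in the child thread, if any

        def run(self):
            try:
                self.obj.run()
            except BaseException as e:
                # Keep the exception for the joining thread instead of losing it
                self.ex = e

        def join(self, timeout=None):
            super().join(timeout)
            # Re-raise in the thread that called join() (e.g. the test)
            if self.ex is not None:
                raise self.ex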
    

    Use it by changing this line in the test:

    get_data = Propogate(ThreadClass(li=[2, 3, 5, 11, 6]))
    

    and by making ThreadClass a plain class that no longer inherits from Thread (the wrapper now provides the thread):

    class ThreadClass:
    

    My answer is adapted from this answer
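For reference, a self-contained sketch of the whole test module with the fix applied. Assumptions: both classes are inlined in one file instead of being imported from thread_class, ThreadClass is the question's class without the Thread base, and the wrapper stores the wrapped object in a public obj attribute. The second test method swaps in a different technique that is not from the answer: concurrent.futures, whose Future.result() re-raises the exception raised by the submitted callable. The third test is a hypothetical passing case.

```python
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from unittest import TestCase


class ThreadClass:
    # The question's class, no longer a Thread subclass
    def __init__(self, li):
        self.li = li
        self.ps = list()

    def __validate(self, val):
        if val > 10:
            raise ValueError("Given value is greater than 10")
        return val % 10

    def __fun(self):
        ps = list()
        for i in self.li:
            ps.append(self.__validate(i))
        self.ps = ps

    def get_ps(self):
        return self.ps

    def run(self):
        self.__fun()


class Propogate(Thread):
    # Runs obj.run() in a new thread and re-raises its exception in join()
    def __init__(self, obj):
        super().__init__()
        self.obj = obj
        self.ex = None

    def run(self):
        try:
            self.obj.run()
        except BaseException as e:
            self.ex = e  # keep the exception for the joining thread

    def join(self, timeout=None):
        super().join(timeout)
        if self.ex is not None:
            raise self.ex


class TestThreadClassNegative(TestCase):

    def test_val_greater_than_valid_fail(self):
        get_data = Propogate(ThreadClass(li=[2, 3, 5, 11, 6]))
        with self.assertRaises(ValueError):
            get_data.start()
            get_data.join()  # the child's ValueError is re-raised here

    def test_val_greater_than_valid_fail_executor(self):
        # Alternative technique: Future.result() re-raises the
        # exception raised by the submitted callable
        get_data = ThreadClass(li=[2, 3, 5, 11, 6])
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(get_data.run)
            with self.assertRaises(ValueError):
                future.result()

    def test_valid_values_pass(self):
        # Hypothetical passing case: 10 is allowed and maps to 0
        get_data = Propogate(ThreadClass(li=[2, 3, 5, 10, 6]))
        get_data.start()
        get_data.join()
        self.assertEqual(get_data.obj.get_ps(), [2, 3, 5, 0, 6])
```

Run it with pytest as in the question; the failure tests now pass because the ValueError surfaces in the test's own thread.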