I have a unittest-based test suite for my Python project.
It contains my test classes, with their test methods, etc.
In some of my tests I call a function that initializes the test scenario. Let's call this function `generate_scenario(...)`; it takes a number of parameters.
I was wondering whether I could write additional Python code that finds every place where `generate_scenario(...)` is called, together with the arguments passed, so I could check that all the "possible" scenarios are actually generated.
Ideally I would like a separate test module that checks this.
> wondering if one could write code to find all the times `generate_scenario(...)` is called
Consult the documentation for the standard-library `inspect` and `dis` modules; they are very helpful for working with code objects and source files.
A demonstration is included below.
There are at least two ways to address your use case:

1. simple text processing (a glorified `grep`), and
2. examining what CPython produces when it compiles the code (the names referenced by each function's code object).
from importlib import import_module
from inspect import getsource, isfunction, isgenerator
from pathlib import Path
from types import CodeType, FunctionType, MethodType, ModuleType
from typing import Callable, Generator, Iterable, NamedTuple
from unittest import TestCase
from unittest.main import TestProgram
import dis
import io
import os
import re
import sys
def find_callable_functions(module: ModuleType | type) -> list[Callable]:
"""Finds callables within a module, including functions and classes."""
return [
for obj in module.__dict__.values()
if callable(obj) and isinstance(obj, (FunctionType, MethodType, type))
# cf inspect.{isfunction, ismethod, isclass}
def find_callable_matches(
module: ModuleType | type, needle: str, verbose: bool = False
) -> Generator[Callable, None, None]:
for obj in module.__dict__.values():
if callable(obj) and isinstance(obj, (FunctionType, MethodType, type)):
if not isgenerator(obj) and isfunction(obj):
buf = io.StringIO()
dis.dis(obj, file=buf)
names = obj.__code__.co_names
if needle in buf.getvalue() and needle in names:
yield obj
if verbose:
# print(dis._disassemble_bytes(code, names=names))
# lines, start = findsource(obj)
# print("".join(lines[start : start + 5]), "\n")
# dis.disassemble(obj.__code__)
class Source(NamedTuple):
    """Location of a chunk of source code, together with its text."""

    # Path of the file the chunk was read from.
    file: Path
    # Line number at which the chunk starts.
    line: int
    # Raw text lines making up the chunk.
    src: list[str]
def find_functions_in(source_file: Path) -> Generator[Source, None, None]:
    """Split a Python file into one text record per function definition.

    This is line-oriented text matching ("glorified grep"), not parsing. A
    new record starts on every line that begins, after optional indentation,
    with ``def `` or ``async def ``, and on an ``if __name__ == "__main__"``
    guard line; it runs until the next such line or the end of the file.
    Consequently, ``class`` headers and class-level statements are attached
    to whichever record precedes them, and lines before the first delimiter
    are dropped. Decorator lines (starting with ``@``) are never kept in
    ``src``.

    Args:
        source_file: path of the Python file to scan (a ``str`` is accepted
            too). It is read as UTF-8, the default Python source encoding.

    Yields:
        ``Source`` records carrying the resolved file path, the 1-based line
        number of the record's first line, and its raw lines (newlines kept).
    """
    decorator = re.compile(r"^\s*@")
    record_delimiter = re.compile(
        r"^(\s*(?:async\s+)?def |if __name__ == .__main__.)"
    )
    path = Path(source_file)
    resolved = path.resolve()
    # Read everything up front so the file is closed before the first yield,
    # instead of staying open while the caller consumes records.
    with open(path, encoding="utf-8") as fin:
        lines = fin.readlines()

    record: Source | None = None
    for lineno, line in enumerate(lines, start=1):
        if record_delimiter.match(line):
            if record is not None:
                yield record
            record = Source(file=resolved, line=lineno, src=[])
        if record is not None and not decorator.match(line):
            record.src.append(line)
    if record is not None:
        yield record
def find_functions_under(
    paths: Iterable[Path | str], needle: str
) -> Generator[Source, None, None]:
    """Yield function records from the ``.py`` files in *paths* that mention *needle*.

    Each path may be a ``Path`` or a ``str``. Entries that are not existing
    regular files with a ``.py`` suffix (directories, other file types) are
    skipped. Every remaining file is split with ``find_functions_in``, and a
    record is yielded when *needle* occurs anywhere in its text. This is a
    plain substring test, so the function that defines *needle* matches too,
    as do mentions in comments or strings.

    Args:
        paths: candidate file paths, e.g. the result of ``Path.glob``.
        needle: text to search for.

    Yields:
        Matching ``Source`` records, in file order, then line order.
    """
    for candidate in paths:
        path = Path(candidate)
        # Cheap suffix test first; is_file() costs a filesystem stat.
        if path.suffix != ".py" or not path.is_file():
            continue
        for record in find_functions_in(path):
            if needle in "".join(record.src):
                yield record
class FirstClass:
    """Demo fixture that accumulates sums into ``x``."""

    def __init__(self, x):
        # Running total, seeded by the caller.
        self.x = x

    def generate_scenario(self, a, b, c):
        """Add ``a + b + c`` to ``self.x``."""
        self.x += a + b + c

    def run_scenario(self):
        """Build the scenario with the fixed arguments 1, 2 and 3."""
        self.generate_scenario(1, 2, 3)
class SecondClass:
    """Demo fixture that accumulates products into ``y``."""

    def __init__(self, y):
        # Running accumulator, seeded by the caller.
        self.y = y

    def generate_scenario(self, a, b, c):
        """Add ``a * b * c`` to ``self.y``."""
        self.y += a * b * c

    def run_scenario(self):
        """Build the scenario with the fixed arguments 1, 2 and 3."""
        # NOTE(review): this method had no body (a SyntaxError). The arguments
        # below are placeholders mirroring FirstClass.run_scenario; confirm
        # the intended values.
        self.generate_scenario(1, 2, 3)
class UnrelatedClass:
    """Fixture class that defines no scenario methods."""

    def __init__(self):
        # Placeholder attribute, initially unset.
        self.z = None
class TestFindFunctions(TestCase):
def test_find_callable_functions(self) -> None:
"<class '_frozen_importlib.FrozenImporter'>",
self.assertEqual(os, import_module("os"))
def test_find_callable_matches(self) -> None:
list(find_callable_matches(FirstClass, "generate_scenario")),
def test_find_functions(self) -> None:
source_records = list(find_functions_in(Path(__file__)))
self.assertEqual(15, len(source_records))
def test_find_functions_under(self, verbose: bool = False) -> None:
source_folder = Path(__file__).parent
glob = source_folder.glob("**/*.py")
records = list(find_functions_under(glob, "generate_scenario"))
self.assertEqual(6, len(records))
if verbose:
for record in records: