I wrote a decorator for printing the recursion tree produced by some function call.
from functools import wraps
def printRecursionTree(func):
    """Decorate *func* so that every call prints a node of the recursion tree.

    Each call prints a header line with the function's qualified name and its
    arguments. The function then runs with this module's global ``print``
    temporarily replaced by a version that prefixes output with the tree
    padding. Finally the return value is printed under a ``╰`` marker.

    The nesting depth lives in the module-global ``_recursiondepth``, so
    several decorated functions (e.g. mutually recursive ones) share one
    tree. The global is reset to 0 every time the decorator is applied.

    Only code that looks up ``print`` in *this* module's globals sees the
    indenting replacement. Functions defined in other modules resolve
    ``print`` through their own globals/builtins and are not indented.

    Args:
        func: the function to trace.

    Returns:
        A wrapper with *func*'s metadata that returns whatever *func* returns.
        If *func* raises, ``print`` and the depth counter are restored before
        the exception propagates.
    """
    global _recursiondepth
    _print = print  # the real print, captured before any patching happens
    _recursiondepth = 0

    def getpads():
        """Return the (call, return, inner-print) prefixes for the current depth."""
        bars = ' │ ' * (_recursiondepth - 1)  # '' at depth 0 and 1
        if _recursiondepth == 0:
            strFn = '{} └──'.format(bars)
            strOther = '{} ▒▒'.format(bars)
            strRet = '{} '.format(bars)
        else:
            strFn = ' {} ├──'.format(bars)
            strOther = ' {} │▒▒'.format(bars)
            strRet = ' {} │ '.format(bars)
        return strFn, strRet, strOther

    def indentedprint():
        """Build a print replacement that prefixes each call with the tree padding."""
        @wraps(print)
        def wrapper(*args, **kwargs):
            _, _, strOther = getpads()
            # Write the prefix to the same stream as the message (None -> stdout).
            _print(strOther, end=' ', file=kwargs.get('file'))
            _print(*args, **kwargs)
        return wrapper

    def formatcall(args, kwargs):
        """Format a call as 'name(a, b, {kwargs}):'; positional args go through str()."""
        parts = [str(arg) for arg in args]  # str() so non-string args can be joined
        if kwargs:
            parts.append(str(kwargs))
        return '{}({}):'.format(func.__qualname__, ', '.join(parts))

    @wraps(func)
    def wrapper(*args, **kwargs):
        global _recursiondepth
        global print
        strFn, strRet, _ = getpads()
        _print(strFn, formatcall(args, kwargs))
        _recursiondepth += 1
        print, backup = indentedprint(), print
        try:
            retval = func(*args, **kwargs)
        finally:
            # Undo the patch even if func raises; otherwise print stays indented
            # and the depth counter drifts for every later call.
            print = backup
            _recursiondepth -= 1
        _print(strRet, '╰', retval)
        if _recursiondepth == 0:
            _print()  # blank line after each complete top-level tree
        return retval
    return wrapper
Example usage:
@printRecursionTree
def fib(n):
    """Return the n-th Fibonacci number, reporting which case each call takes.

    Values n <= 1 are returned unchanged (the base case); larger n recurse
    on n-1 first and then n-2, so the traced tree lists the left branch first.
    """
    if n > 1:
        print('Recursive Case')
        return fib(n - 1) + fib(n - 2)
    print('Base Case')
    return n
# This works with mutually recursive functions too,
# since the variable _recursiondepth is global
@printRecursionTree
def iseven(n):
    """Return True if n is even, by delegating to isodd(n - 1).

    The three print calls show that every line printed inside a traced call
    is indented under its node. Assumes n >= 0: a negative n never reaches
    the n == 0 base case and recurses until RecursionError.
    """
    print('checking if even')
    print('Note how the print')
    print('statements get nicely indented')
    if n == 0:
        return True
    return isodd(n - 1)


@printRecursionTree
def isodd(n):
    """Return True if n is odd, by delegating to iseven(n - 1).

    Assumes n >= 0, for the same reason as iseven.
    """
    print('checking if odd')
    if n == 0:
        return False
    return iseven(n - 1)
# Run both demos. Each top-level traced call prints its tree followed by a
# blank line (the decorator's wrapper calls _print() when the depth is back to 0).
iseven(5)
fib(5)
# Output transcript recorded for the two calls above. It is a bare string
# expression, so it is evaluated and discarded with no runtime effect.
'''Prints:
└── iseven(5):
│▒▒ checking if even
│▒▒ Note how the print
│▒▒ statements get nicely indented
├── isodd(4):
│ │▒▒ checking if odd
│ ├── iseven(3):
│ │ │▒▒ checking if even
│ │ │▒▒ Note how the print
│ │ │▒▒ statements get nicely indented
│ │ ├── isodd(2):
│ │ │ │▒▒ checking if odd
│ │ │ ├── iseven(1):
│ │ │ │ │▒▒ checking if even
│ │ │ │ │▒▒ Note how the print
│ │ │ │ │▒▒ statements get nicely indented
│ │ │ │ ├── isodd(0):
│ │ │ │ │ │▒▒ checking if odd
│ │ │ │ │ ╰ False
│ │ │ │ ╰ False
│ │ │ ╰ False
│ │ ╰ False
│ ╰ False
╰ False
└── fib(5):
│▒▒ Recursive Case
├── fib(4):
│ │▒▒ Recursive Case
│ ├── fib(3):
│ │ │▒▒ Recursive Case
│ │ ├── fib(2):
│ │ │ │▒▒ Recursive Case
│ │ │ ├── fib(1):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 1
│ │ │ ├── fib(0):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 0
│ │ │ ╰ 1
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ╰ 2
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ╰ 3
├── fib(3):
│ │▒▒ Recursive Case
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ├── fib(1):
│ │ │▒▒ Base Case
│ │ ╰ 1
│ ╰ 2
╰ 5
'''
This example code works correctly as long as it is in the same file where the decorator is defined.
If, however, the decorator is imported from another module, the print calls no longer get indented.
I understand that this happens because the `print` name patched by the decorator is a global of the decorator's own module, and module globals are not shared across modules.
You can change the behavior of the built-in `print` function for all modules by replacing it in the `builtins` module.
So change your assignments to the global variable `print` into assignments to `builtins.print` (after importing `builtins`):
import builtins
...
@wraps(func)
def wrapper(*args, **kwargs):
    """Print a tree node for this call of func, run it, then print its result.

    While func runs, builtins.print is replaced by an indenting version.
    Any module that does not define its own global ``print`` resolves the
    name through builtins, so decorated functions defined anywhere get
    indented output. The original builtins.print and the depth counter are
    restored even if func raises.
    """
    global _recursiondepth  # no more need for global print up here
    strFn, strRet, _ = getpads()
    parts = [str(arg) for arg in args]  # str() so non-string args can be joined
    if kwargs:
        parts.append(str(kwargs))
    _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(parts)))
    _recursiondepth += 1
    builtins.print, backup = indentedprint(), builtins.print  # change here
    try:
        retval = func(*args, **kwargs)
    finally:
        # and here: this patch is process-wide, so it must be undone even
        # when func raises, or every later print in the program stays indented
        builtins.print = backup
        _recursiondepth -= 1
    _print(strRet, '╰', retval)
    if _recursiondepth == 0:
        _print()  # blank line after each complete top-level tree
    return retval