Search code examples
pythondecoratorpython-decoratorsmonkeypatching

Python: Patch the print function for a particular function call? (Decorator for printing the recursion tree)


I wrote a decorator for printing the recursion tree produced by some function call.

from functools import wraps

def printRecursionTree(func):
    global _recursiondepth
    _print = print
    _recursiondepth = 0

    def getpads():
        if _recursiondepth == 0:
            strFn    = '{} └──'.format(' │  ' * (_recursiondepth-1))
            strOther = '{}  ▒▒'.format(' │  ' * (_recursiondepth-1))
            strRet   = '{}    '.format(' │  ' * (_recursiondepth-1))
        else:
            strFn    = '    {} ├──'.format(' │  ' * (_recursiondepth-1))
            strOther = '    {} │▒▒'.format(' │  ' * (_recursiondepth-1))
            strRet   = '    {} │  '.format(' │  ' * (_recursiondepth-1))

        return strFn, strRet, strOther

    def indentedprint():
        @wraps(print)
        def wrapper(*args, **kwargs):
            strFn, strRet, strOther = getpads()
            _print(strOther, end=' ')
            _print(*args, **kwargs)
        return wrapper


    @wraps(func)
    def wrapper(*args, **kwargs):
        global _recursiondepth
        global print

        strFn, strRet, strOther = getpads()

        if args and kwargs:
            _print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
        else:
            _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
        _recursiondepth += 1
        print, backup = indentedprint(), print
        retval = func(*args, **kwargs)
        print = backup
        _recursiondepth -= 1
        _print(strRet, '╰', retval)
        if _recursiondepth == 0:
            _print()
        return retval

    return wrapper

Example usage:

@printRecursionTree
def fib(n):
    if n <= 1:
        print('Base Case')
        return n
    print('Recursive Case')
    return fib(n-1) + fib(n-2)

# This works with mutually recursive functions too,
# since the variable _recursiondepth is global
@printRecursionTree
def iseven(n):
    print('checking if even')
    if n == 0: return True
    return isodd(n-1)

@printRecursionTree
def isodd(n):
    print('checking if odd')
    if n == 0: return False
    return iseven(n-1)

iseven(5)
fib(5)

'''Prints:

└── iseven(5):
     │▒▒ checking if even
     │▒▒ Note how the print
     │▒▒ statements get nicely indented
     ├── isodd(4):
     │   │▒▒ checking if odd
     │   ├── iseven(3):
     │   │   │▒▒ checking if even
     │   │   │▒▒ Note how the print
     │   │   │▒▒ statements get nicely indented
     │   │   ├── isodd(2):
     │   │   │   │▒▒ checking if odd
     │   │   │   ├── iseven(1):
     │   │   │   │   │▒▒ checking if even
     │   │   │   │   │▒▒ Note how the print
     │   │   │   │   │▒▒ statements get nicely indented
     │   │   │   │   ├── isodd(0):
     │   │   │   │   │   │▒▒ checking if odd
     │   │   │   │   │   ╰ False
     │   │   │   │   ╰ False
     │   │   │   ╰ False
     │   │   ╰ False
     │   ╰ False
     ╰ False

 └── fib(5):
     │▒▒ Recursive Case
     ├── fib(4):
     │   │▒▒ Recursive Case
     │   ├── fib(3):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(2):
     │   │   │   │▒▒ Recursive Case
     │   │   │   ├── fib(1):
     │   │   │   │   │▒▒ Base Case
     │   │   │   │   ╰ 1
     │   │   │   ├── fib(0):
     │   │   │   │   │▒▒ Base Case
     │   │   │   │   ╰ 0
     │   │   │   ╰ 1
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ╰ 2
     │   ├── fib(2):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ├── fib(0):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 0
     │   │   ╰ 1
     │   ╰ 3
     ├── fib(3):
     │   │▒▒ Recursive Case
     │   ├── fib(2):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ├── fib(0):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 0
     │   │   ╰ 1
     │   ├── fib(1):
     │   │   │▒▒ Base Case
     │   │   ╰ 1
     │   ╰ 2
     ╰ 5
'''

This example code works correctly as long as it is in the same file where the decorator is defined.

If however one imports the decorator from some module, the print statements no longer get indented.

I understand that this behaviour arises because the print statement patched by the decorator is a global for it's own module, is not shared across modules.

  1. How do I fix this?
  2. Is there a better way of patching a function only for a particular call to another function?

Solution

  • You can change the behavior of the builtin print function for all modules by replacing it in the builtins module.

    So change your assignments to global variable print with assignments to builtins.print (after importing builtins):

    import builtins
    
    ...
    
        @wraps(func)
        def wrapper(*args, **kwargs):
            global _recursiondepth # no more need for global print up here
    
            strFn, strRet, strOther = getpads()
    
            if args and kwargs:
                _print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
            else:
                _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
            _recursiondepth += 1
            builtins.print, backup = indentedprint(), print   # change here
            retval = func(*args, **kwargs)
            builtins.print = backup                           # and here
            _recursiondepth -= 1
            _print(strRet, '╰', retval)
            if _recursiondepth == 0:
                _print()
            return retval