Tags: python, reflection, closures, nested-function

Warn for every (nested) function with free variables (recursively)


I'd like to do the following:

for every nested function f anywhere in this_py_file:
    if has_free_variables(f):
        print warning

Why? Primarily as insurance against the late-binding closure gotcha as described elsewhere. Namely:

>>> def outer():
...     rr = []
...     for i in range(3):
...         def inner():
...             print(i)
...         rr.append(inner)
...     return rr
... 
>>> for f in outer(): f()
... 
2
2
2
>>> 

And whenever I get warned about a free variable, I would either add an explicit exception (in the rare case that I would want this behaviour) or fix it like so:

...         def inner(i=i):

Then the behaviour becomes more like inner classes in Java, where any local variable used by an inner class has to be final (or, since Java 8, effectively final).
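The session above can be reproduced as a Python 3 script that returns the values instead of printing them, with the default-argument fix side by side. `outer_fixed` is a hypothetical name used only for this comparison. The last two lines show that the fix also removes `i` from the inner function's free variables, which is exactly what the requested warning would key on.

```python
def outer():
    rr = []
    for i in range(3):
        def inner():
            # `i` is a free variable: it is looked up when inner() runs,
            # by which time the loop has finished and i == 2
            return i
        rr.append(inner)
    return rr


def outer_fixed():
    rr = []
    for i in range(3):
        # A default value is evaluated when `def` executes,
        # so each inner() keeps the i from its own iteration
        def inner(i=i):
            return i
        rr.append(inner)
    return rr


print([f() for f in outer()])        # [2, 2, 2]
print([f() for f in outer_fixed()])  # [0, 1, 2]
print(outer()[0].__code__.co_freevars)        # ('i',)
print(outer_fixed()[0].__code__.co_freevars)  # ()
```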

(As far as I know, besides solving the late-binding issue, this will also promote better use of memory, because if a function "closes over" some variables in an outer scope, then the outer scope cannot be garbage collected for as long as the function is around. Right?)
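One way to check what a closure actually keeps alive is the function's `__closure__` attribute. In CPython it holds one cell per free variable rather than a reference to the whole outer frame, so only the objects bound to the captured names stay reachable through the function. The names `big` and `unused` below are illustrative.

```python
def outer():
    big = list(range(1000))     # read by inner(), so it is captured in a cell
    unused = list(range(1000))  # never read by inner(), so it is not captured

    def inner():
        return len(big)

    return inner


f = outer()
print(f.__code__.co_freevars)              # ('big',)
print(len(f.__closure__))                  # 1
print(f.__closure__[0].cell_contents[:3])  # [0, 1, 2]
```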

I can't find any way to get hold of functions nested in other functions. Currently, the best way I can think of is to instrument a parser, which seems like a lot of work.


Solution

  • Consider the following function:

    def outer_func():
        outer_var = 1
    
        def inner_func():
            inner_var = outer_var
            return inner_var
    
        outer_var += 1
        return inner_func
    

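    The function's __code__ attribute gives its code object, and the code object of the inner function is stored among that object's constants (co_consts):

    import types

    outer_code = outer_func.__code__
    # Select the nested code object by type: its index in co_consts is a CPython
    # implementation detail that can shift when the body gains other constants
    inner_code = next(c for c in outer_code.co_consts if isinstance(c, types.CodeType))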
    

    From this code object, the free variables can be recovered:

    inner_code.co_freevars # ('outer_var',)
    

    Not every entry in co_consts is a code object (the rest are plain constants such as None, numbers and strings), so check before inspecting an entry:

    hasattr(inner_code, 'co_freevars') # True
    

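    Once you have collected the functions defined in your file (for example with inspect.getmembers(module, inspect.isfunction)), the check might look something like this. Note that it only looks one level deep:

    for func in function_list:
        # Inspect each function's own constants, and don't slice co_consts:
        # in some Python versions the nested code object is the last entry
        for const in func.__code__.co_consts:
            if hasattr(const, 'co_freevars'):
                assert not const.co_freevars, (func.__name__, const.co_name, const.co_freevars)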
    
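The loop above only inspects direct children, whereas the question asks for every nested function, recursively. Below is a minimal sketch that compiles the file's source (rather than importing it) and walks the code objects depth-first. `iter_code_objects`, `warn_free_variables` and `SOURCE` are hypothetical names, and reporting through `warnings.warn` instead of `assert` is a choice made for this sketch.

```python
import types
import warnings


def iter_code_objects(code):
    """Yield `code` and every code object nested inside it, depth-first."""
    yield code
    for const in code.co_consts:
        # Nested functions, lambdas, class bodies and (before Python 3.12)
        # comprehensions are stored as code objects among the constants.
        if isinstance(const, types.CodeType):
            yield from iter_code_objects(const)


def warn_free_variables(source, filename="<string>"):
    """Warn about every code object in `source` that has free variables."""
    found = []
    for code in iter_code_objects(compile(source, filename, "exec")):
        if code.co_freevars:
            found.append((code.co_name, code.co_firstlineno, code.co_freevars))
            warnings.warn(
                "%s:%d: %s() has free variables %s"
                % (filename, code.co_firstlineno, code.co_name, code.co_freevars)
            )
    return found


SOURCE = """\
def outer():
    rr = []
    for i in range(3):
        def inner():
            return i
        rr.append(inner)
    return rr

def fixed():
    rr = []
    for i in range(3):
        def inner(i=i):
            return i
        rr.append(inner)
    return rr
"""

print(warn_free_variables(SOURCE))  # [('inner', 4, ('i',))]
```

To check the file itself, pass it its own source and path (`__file__`). Expect some hits you may want to whitelist: a method that uses zero-argument `super()` gets `__class__` as a free variable, and before Python 3.12 a list, dict or set comprehension that reads a local of the enclosing function is also a nested code object with free variables.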

    Someone who knows more about the inner workings can probably provide a better explanation or a more concise solution.
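A possibly more concise route, which also avoids importing or executing the file, is the standard-library `symtable` module: it builds the compiler's symbol tables from source and reports each function scope's free variables through `get_frees()`. This is a different technique from the co_consts walk above; `free_vars_by_scope` and the dotted scope naming are hypothetical choices for this sketch.

```python
import symtable


def free_vars_by_scope(source, filename="<string>"):
    """Return {dotted scope name: free variable names} for every function scope."""
    result = {}

    def walk(table, prefix):
        for child in table.get_children():
            name = prefix + child.get_name()
            # Function scopes (def and lambda) are symtable.Function instances
            if isinstance(child, symtable.Function) and child.get_frees():
                result[name] = child.get_frees()
            walk(child, name + ".")

    walk(symtable.symtable(source, filename, "exec"), "")
    return result


SOURCE = """\
def outer():
    rr = []
    for i in range(3):
        def inner():
            return i
        rr.append(inner)
    return rr
"""

print(free_vars_by_scope(SOURCE))  # {'outer.inner': ('i',)}
```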