Search code examples
python abstract-syntax-tree

How do I apply an ast.NodeTransformer to imports?


I'd like to apply a NodeTransformer not just to the current file's AST, but to any code it imports as well. If you run the code below, you will see that the transformer works, but only for the single file that is read and parsed. How would I modify this code to apply the transformer to any imports in the parsed code?

a.py:

from b import q


# Return ``q(a) + 5``, where ``q`` is the function imported from b.py.
#
# Documented with comments rather than a docstring on purpose: a docstring
# would be one more statement in the function body, and the transformer in
# Main.py emits a "Doing: ..." line for every statement in a function body.
def r(a):
    return q(a) + 5

b.py:

# Return ``r(n + 1)``.  The name ``r`` is looked up in this module (b.py),
# so this calls the ``r`` defined below, not the ``r`` defined in a.py.
def q(n):
    return r(n + 1)


# Identity function: return ``n`` unchanged.
def r(n):
    return n

Main.py:

import ast
import astor


class trivial_transformer(ast.NodeTransformer):
    """Make every function announce each of its statements before running it."""

    def visit_FunctionDef(self, node):
        """Insert ``print("Doing: <source>")`` ahead of every body statement.

        ``<source>`` is the statement rendered back to text with
        ``astor.to_source``.  The function node is mutated in place and
        returned.  The inserted nodes carry no position information, so the
        caller must run ``ast.fix_missing_locations`` before compiling.
        """
        body = []
        for line in node.body:
            # Render the message before visiting: visiting may rewrite the
            # statement in place (e.g. a nested function gets instrumented),
            # and the message should show the statement as originally written.
            message = "Doing: " + astor.to_source(line).strip()
            body.append(
                ast.Expr(
                    ast.Call(func=ast.Name('print', ctx=ast.Load()),
                             # ast.Constant replaces ast.Str, which is
                             # deprecated and removed in Python 3.12.
                             args=[ast.Constant(value=message)],
                             keywords=[])))
            # self.visit dispatches on the statement itself; generic_visit
            # would only visit its children, so a function defined directly
            # inside this body would never reach visit_FunctionDef.
            visited = self.visit(line)
            if isinstance(visited, ast.AST):
                body.append(visited)
            elif visited is not None:
                # A visitor may return a list of replacement statements.
                body.extend(visited)
        node.body = body
        return node


# Parse a.py, instrument it, and run the instrumented code in a fresh namespace.
# The file is opened in a ``with`` block so the handle is closed after reading.
with open('a.py', 'r') as source_file:
    source_tree = ast.parse(source_file.read())
parsed_ast = ast.fix_missing_locations(trivial_transformer().visit(source_tree))
g = {}
# The code object is compiled in 'exec' mode, so run it with exec().
exec(compile(parsed_ast, '<source>', 'exec'), g)
print(g['r'](5))

This yields:

Doing: return q(a) + 5
11

But I'd like it to yield:

Doing: return q(a) + 5
Doing: return r(n + 1)
Doing: return n
11

Solution

  • Well, it took some doing, but I got it (woo):

    import ast
    import importlib
    import importlib.util
    import sys

    import astor
    
    class trivial_transformer(ast.NodeTransformer):
        """Instrument function bodies with ``print`` calls, following imports.

        Every statement inside a function is preceded by a
        ``print("Doing: <source>")`` call.  Whenever an absolute ``import`` /
        ``from ... import`` is encountered, the imported module is loaded
        through this same transformer and registered in ``sys.modules`` so the
        real import statement later picks up the instrumented version.
        """

        def processImport(self, imp):
            """Load module ``imp`` through this transformer, if possible.

            Does nothing when ``imp`` is already in ``sys.modules``, cannot be
            found, or has no Python source available (e.g. built-in or
            extension modules); in those cases the ordinary import machinery
            handles (or rejects) the module when the import statement runs.
            """
            if imp in sys.modules:
                return
            try:
                spec = importlib.util.find_spec(imp)
                loader = getattr(spec, 'loader', None)
                get_source = getattr(loader, 'get_source', None)
                source = get_source(imp) if get_source is not None else None
            except (ImportError, ValueError):
                # Not resolvable here; the real import will report the error.
                return
            if source is None:
                return

            helper = importlib.util.module_from_spec(spec)
            # Register before transforming/executing: visiting the module's
            # own imports calls back into processImport, and without this
            # entry a circular import would recurse without end.
            sys.modules[imp] = helper
            try:
                parsed_dep = ast.fix_missing_locations(self.visit(ast.parse(source)))
                # Compile under the file path when known so tracebacks point
                # at the real file rather than at the bare module name.
                exec(compile(parsed_dep, spec.origin or imp, 'exec'), helper.__dict__)
            except BaseException:
                # Do not leave a half-initialised module behind.
                sys.modules.pop(imp, None)
                raise

        def visit_ImportFrom(self, node):
            """Pre-load the source module of an absolute ``from X import ...``."""
            # Relative imports (level > 0, possibly with module=None) cannot
            # be resolved without knowing the enclosing package, so they are
            # left to the normal import machinery.
            if node.module is not None and node.level == 0:
                self.processImport(node.module)
            return node

        def visit_Import(self, node):
            """Pre-load every module named in an ``import a, b`` statement."""
            for i in node.names:
                self.processImport(i.name)
            return node

        def visit_FunctionDef(self, node):
            """Insert ``print("Doing: <source>")`` ahead of every body statement.

            The function node is mutated in place and returned.  The inserted
            nodes carry no position information; ``ast.fix_missing_locations``
            must be applied before compiling.
            """
            body = []
            for line in node.body:
                # Render the message before visiting: visiting may rewrite the
                # statement in place, and the message should show it as
                # originally written.
                message = "Doing: " + astor.to_source(line).strip()
                body.append(
                    ast.Expr(
                        ast.Call(func=ast.Name('print', ctx=ast.Load()),
                                 # ast.Constant replaces ast.Str, which is
                                 # deprecated and removed in Python 3.12.
                                 args=[ast.Constant(value=message)],
                                 keywords=[])))
                # self.visit dispatches on the statement itself, so imports and
                # function definitions placed directly in this body are handled
                # too; generic_visit would only look at their children.
                visited = self.visit(line)
                if isinstance(visited, ast.AST):
                    body.append(visited)
                elif visited is not None:
                    # A visitor may return a list of replacement statements.
                    body.extend(visited)
            node.body = body
            return node
    
    
    # Load the entry module through the transformer first; processImport
    # stores the instrumented module in sys.modules under this name.
    init = 'a'
    trivial_transformer().processImport(init)

    # This import now finds the instrumented module already in sys.modules.
    import a
    # Print the result so the final "11" line of the desired output appears.
    print(a.r(5))