Tags: python, python-2.7, abstract-syntax-tree

Can I handle imports in an Abstract Syntax Tree?


I want to parse config.py and check that it contains only admissible nodes. config.py can import other config files, which must also be checked.

Is there any functionality in the ast module that turns ast.Import and ast.ImportFrom objects into the ast.Module objects of the imported files?

Here is a code example. I am checking a configuration file (path_to_config), but I also want to check any files that it imports:

import ast

with open(path_to_config) as config_file:
    ast_tree = ast.parse(config_file.read())
    for script_object in ast_tree.body:
        if isinstance(script_object, ast.Import):
            pass  # Imported file must be checked too
        elif isinstance(script_object, ast.ImportFrom):
            pass  # Imported file must be checked too
        elif not _is_admissible_node(script_object):
            raise Exception("Config file '%s' contains unacceptable statements" % path_to_config)

Solution

  • This is a little more complex than you think. from foo import name is a valid way of importing either an object called name that is defined in the foo module, or the foo.name submodule, so you may have to try both forms to see which one resolves to a file. Python also allows for aliases: code can import foo.bar while the actual module is defined as foo._bar_implementation and made available as an attribute of the foo package. You can't detect all of these cases purely by looking at Import and ImportFrom nodes.
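    To make the ambiguity concrete, here is a minimal sketch that lists the module names a single absolute import statement might load. candidate_modules is a hypothetical helper, and relative imports (level > 0) are deliberately left out. Only the AST is inspected, so nothing is imported or executed.

```python
import ast


def candidate_modules(node):
    """Return the module names an absolute import statement *might* load.

    For "from foo import name" the AST alone can't tell whether name is an
    attribute of foo or the submodule foo.name, so both are returned.
    """
    if isinstance(node, ast.Import):
        return [alias.name for alias in node.names]
    if isinstance(node, ast.ImportFrom) and not node.level:
        candidates = [node.module]
        for alias in node.names:
            if alias.name != '*':  # "from foo import *" names no submodule
                candidates.append('{}.{}'.format(node.module, alias.name))
        return candidates
    return []  # not an import, or a relative import (not handled here)


tree = ast.parse("import foo.bar\nfrom foo import name, other\n")
for stmt in tree.body:
    print(candidate_modules(stmt))
```

    Whether foo.name actually exists as a module can only be decided by looking for it on disk.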

    If you ignore those cases and only look at the module named in the from clause, you still have to turn each module name into a filename and then parse the source of that file, for every import.
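    As an illustration of that name-to-file step without any import machinery, here is a simplified sketch that walks plain directories. find_source_file and parse_module are hypothetical helpers; they only look for .py files and __init__.py packages, and they ignore zip imports, C extensions, bytecode-only modules and namespace packages.

```python
import ast
import os
import tempfile


def find_source_file(module_name, search_path):
    """Map a dotted module name to the .py file that defines it, or None.

    Deliberately simplified: each search-path entry is tried on its own,
    and zip imports, C extensions, bytecode-only modules and namespace
    packages are ignored.
    """
    parts = module_name.split('.')
    for base in search_path:
        directory = base
        # every leading component must be a package: a directory with __init__.py
        for part in parts[:-1]:
            directory = os.path.join(directory, part)
            if not os.path.isfile(os.path.join(directory, '__init__.py')):
                break
        else:
            last = os.path.join(directory, parts[-1])
            package_init = os.path.join(last, '__init__.py')
            if os.path.isfile(package_init):
                return package_init  # the name itself is a package
            if os.path.isfile(last + '.py'):
                return last + '.py'
    return None


def parse_module(module_name, search_path):
    """Parse the module's source into an ast.Module, or return None."""
    filename = find_source_file(module_name, search_path)
    if filename is None:
        return None
    with open(filename) as source:
        # passing the filename makes a SyntaxError name the offending file
        return ast.parse(source.read(), filename)


# Usage: build a throwaway pkg/ with one submodule and resolve it.
root = tempfile.mkdtemp()
os.mkdir(os.path.join(root, 'pkg'))
for relpath, text in [('pkg/__init__.py', ''), ('pkg/settings.py', 'DEBUG = True\n')]:
    with open(os.path.join(root, *relpath.split('/')), 'w') as f:
        f.write(text)

print(find_source_file('pkg.settings', [root]))
print([type(node).__name__ for node in parse_module('pkg.settings', [root]).body])
```

    In a real checker you would presumably pass sys.path (with the config file's own directory first) as search_path.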

    In Python 2 you can use imp.find_module() to get an open file object for the module (*). You want to keep the full module name around when parsing each module, because you'll need it to figure out package-relative imports later on. imp.find_module() can't handle dotted names of modules inside packages, so I created a wrapper function:

    import imp

    _package_paths = {}

    def find_module(module):
        # imp.find_module() can't handle dotted names, so walk the packages ourselves.
        # Returns an open file object, the filename, and a flag indicating if this
        # is a package directory with an __init__.py file; (None, None, False)
        # means there is no Python source to check.
        path = None
        try:
            if '.' in module:
                # resolve the package path first
                parts = module.split('.')
                module = parts.pop()
                for i, part in enumerate(parts, 1):
                    name = '.'.join(parts[:i])
                    if name in _package_paths:
                        path = [_package_paths[name]]
                    else:
                        file_, filename, (_, _, type_) = imp.find_module(part, path)
                        if file_ is not None:
                            file_.close()
                        if type_ != imp.PKG_DIRECTORY:
                            # not a package, so it can't contain the module; abort search
                            return None, None, False
                        _package_paths[name] = filename
                        path = [filename]
            source, filename, (_, _, type_) = imp.find_module(module, path)
            is_package = False
            if type_ == imp.PKG_DIRECTORY:
                # load the __init__ file in the package
                source, filename, (_, _, type_) = imp.find_module('__init__', [filename])
                is_package = True
        except ImportError:
            # not a module (e.g. "name" in "from foo import name"), or not found
            return None, None, False
        if type_ != imp.PY_SOURCE:
            # builtin, C extension or bytecode-only module: nothing to parse
            if source is not None:
                source.close()
            return None, None, False
        return source, filename, is_package
    

    I'd also track which module names you have already processed so you don't parse them twice. Track them under their full dotted names; on Python 3 you can use the name attribute of the spec object to make sure you record their canonical names.
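    Keying that set on the dotted module name (ast.alias.name) rather than the local alias (ast.alias.asname) is what lets repeated imports collapse into one entry. A small sketch with a hypothetical helper, covering only top-level import statements:

```python
import ast


def imported_module_names(source):
    """Collect the dotted module names of all top-level import statements.

    "import os.path as p" and "import os.path" both refer to the module
    os.path; the alias only changes the local variable name.
    """
    seen = set()
    for node in ast.parse(source).body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                seen.add(alias.name)  # alias.asname is ignored on purpose
    return seen


print(sorted(imported_module_names("import os.path as p\nimport os.path\nimport json\n")))
```

    This does not see through the foo._bar_implementation style of aliasing mentioned earlier; that needs the import system's own view of the module.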

    Use a stack to process all the modules:

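    import ast

    with open(path_to_config) as config_file:
        # stack consists of (package, filename, ast) tuples; the config file
        # itself is not part of any package
        stack = [('', path_to_config, ast.parse(config_file.read(), path_to_config))]

    seen = set()
    while stack:
        package, filename, ast_tree = stack.pop()
        for script_object in ast_tree.body:
            if isinstance(script_object, (ast.Import, ast.ImportFrom)):
                names = [a.name for a in script_object.names]
                from_names = []
                if isinstance(script_object, ast.ImportFrom):
                    from_names = [n for n in names if n != '*']
                    name = script_object.module
                    if script_object.level:
                        # relative import: level 1 is the current package and
                        # every extra level moves up one parent package
                        if not package or script_object.level - 1 > package.count('.'):
                            continue  # can't resolve this relative import
                        base = package.rsplit('.', script_object.level - 1)[0]
                        name = '{}.{}'.format(base, name) if name else base
                    # "from foo import bar" may also import the foo.bar submodule
                    names = [name] + ['{}.{}'.format(name, fn) for fn in from_names]
                for name in names:
                    if name in seen:
                        continue
                    source, mod_filename, is_package = find_module(name)
                    if source is None:
                        continue  # not a module, or no Python source to check
                    if is_package and from_names and name == names[0]:
                        # importing from a package: skip its __init__.py and
                        # assume the imported names are submodules
                        source.close()
                        continue
                    seen.add(name)
                    with source:
                        module_ast = ast.parse(source.read(), mod_filename)
                    # a package's __init__ belongs to the package itself,
                    # a plain module to its parent package
                    mod_package = name if is_package else name.rpartition('.')[0]
                    stack.append((mod_package, mod_filename, module_ast))
            elif not _is_admissible_node(script_object):
                raise Exception("Config file '%s' contains unacceptable statements" % filename)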
    

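    For from foo import bar imports where foo is a package, foo/__init__.py is skipped and bar is assumed to be a module. Keep in mind that Python does execute foo/__init__.py for such an import, so this shortcut can leave that file unchecked.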
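    Relative imports are the easiest part of the loop to get wrong, so here is the resolution step isolated as a standalone sketch (resolve_relative is a hypothetical helper). It follows Python's rule: level 1 is the importing module's own package, and each extra dot moves up one parent package.

```python
def resolve_relative(package, module, level):
    """Return the absolute name for "from <level dots><module> import ...".

    package is the importing module's package: the module's own dotted name
    for a package __init__.py, its parent package's name otherwise.
    module is ImportFrom.module (None for "from . import x").
    """
    if level == 0:
        return module  # absolute import, nothing to resolve
    # level 1 is the package itself; each extra dot moves up one parent
    if not package or level - 1 > package.count('.'):
        raise ValueError('relative import beyond the top-level package')
    base = package.rsplit('.', level - 1)[0]
    return '{}.{}'.format(base, module) if module else base


print(resolve_relative('configs.site', 'defaults', 1))  # configs.site.defaults
print(resolve_relative('configs.site', 'defaults', 2))  # configs.defaults
print(resolve_relative('configs.site', None, 2))        # configs
```

    On Python 3, importlib.util.resolve_name('..c', 'a.b') performs the same computation and returns 'a.c'.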


    (*) The imp module is deprecated in Python 3 (since 3.4) and was removed entirely in Python 3.12. On Python 3 you would use importlib.util.find_spec() to get the module spec, and then use the ModuleSpec.origin attribute to get the filename. importlib.util.find_spec() knows how to handle packages.
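    A minimal Python 3 sketch of that approach follows; find_source is a hypothetical helper that returns (filename, is_package) or None. One caveat for a config checker: for a dotted name, find_spec() imports the parent packages, which executes their __init__.py files before you get a chance to check them.

```python
import importlib.util


def find_source(module_name):
    """Python 3 sketch: return (filename, is_package) for a module's .py source.

    Returns None for missing modules, builtins, C extensions and other
    modules without a .py origin. Caution: for a dotted name, find_spec()
    imports (and therefore executes) the parent packages.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # e.g. a parent package is missing or isn't a package,
        # or an already-imported module has no __spec__
        return None
    if spec is None or not spec.has_location or not spec.origin.endswith('.py'):
        return None
    # packages have submodule_search_locations; plain modules have None
    return spec.origin, spec.submodule_search_locations is not None


print(find_source('json'))          # the stdlib json package's __init__.py
print(find_source('json.decoder'))  # a plain module inside that package
print(find_source('sys'))           # builtin, so None
```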