
Using the Python ast module to parse callables passed as arguments


I'm writing a script that will walk through various Python modules and hunt for calls to fx.bind(). The first argument to this function is a string representing the hotkey, and the second argument is a function that will be run when the hotkey is pressed. For example, fx.bind('o', my_func).

I have created a subclass of ast.NodeVisitor that implements visit_Call() to check all Call nodes. What I'm wondering is how I can get the actual instance of the callable that's passed as the second argument.

The idea is that I can then store my own dictionary, where the keys are the functions and the values are the hotkeys they're assigned to, i.e. {<function my_func at 0x0>: 'o'}
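To see what visit_Call() actually receives for the example call, it helps to look at the parsed node. A minimal sketch using only the standard ast module; fx and my_func are the question's names, and nothing is imported or executed because ast.parse only reads the text:

```python
import ast

# Parse the example call as a single expression; fx and my_func need not exist.
tree = ast.parse("fx.bind('o', my_func)", mode="eval")
call = tree.body  # the ast.Call node that visit_Call() would receive

print(type(call.func).__name__, call.func.attr)  # Attribute bind
hotkey_node, callback_node = call.args
print(type(hotkey_node).__name__, repr(hotkey_node.value))  # Constant 'o'
print(type(callback_node).__name__, callback_node.id)  # Name my_func
```

The second argument is only ast.Name(id='my_func'), i.e. a string, so getting the function object itself means importing the module and looking that name up on it, for example with getattr().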

Any help would be much appreciated!


Solution

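```python
import ast
import collections
import importlib
from types import ModuleType
from typing import Any, Callable


class NV(ast.NodeVisitor):
    """Collect fx.bind(<hotkey>, <callback>) calls and map each callback to its hotkeys."""

    def __init__(self, module: str, source: ast.AST) -> None:
        # Importing the module executes it, so fx (and anything else it imports) must be importable.
        self.module: ModuleType = importlib.import_module(module)
        self.funcs: dict[Callable[..., Any], list[str]] = collections.defaultdict(list)
        self.visit(source)

    def visit_Call(self, node: ast.Call) -> None:
        match node:
            # fx.bind('o', callback): the callback is defined in the scanned module itself.
            case ast.Call(
                func=ast.Attribute(value=ast.Name(id='fx'), attr='bind'),
                args=[ast.Constant(value=str() as hotkey), ast.Name(id=cb_id)],
            ):
                self.funcs[getattr(self.module, cb_id)].append(hotkey)

            # fx.bind('o', other_module.callback): the callback lives in another module.
            case ast.Call(
                func=ast.Attribute(value=ast.Name(id='fx'), attr='bind'),
                args=[
                    ast.Constant(value=str() as hotkey),
                    ast.Attribute(value=ast.Name(id=cb_module), attr=cb_id),
                ],
            ):
                # Look the module up in the scanned module's namespace, so aliases
                # such as "import other_module as om" resolve correctly.
                module = getattr(self.module, cb_module)
                self.funcs[getattr(module, cb_id)].append(hotkey)

        # Keep walking so that calls nested inside this one are visited too.
        self.generic_visit(node)


with open('pymodule.py') as source_file:
    source_tree = ast.parse(source_file.read())

node_visitor = NV('pymodule', source_tree)
print(dict(node_visitor.funcs))
```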

Functionality:

1. The script visits every Call node.
2. A match statement checks whether the call has one of the following forms:
   • fx.bind('o', callback)
   • fx.bind('o', other_module.callback)
3. Further processing:
   • First form: the callback is looked up in the same module as the call.
   • Second form: the callback is fetched from other_module.
4. The hotkey is appended to the callback's entry in the dictionary.

Differences from the requested solution:

• I took the liberty of using a defaultdict(list) instead of the requested plain dict, so each callback maps to a list of hotkeys. This prevents one binding from overwriting another when the same callback is bound more than once in a module.
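To illustrate why the values are lists: with a plain dict keyed by function, a second binding of the same callback replaces the first, while defaultdict(list) keeps both. A small sketch with a hypothetical my_func:

```python
import collections


def my_func() -> None:
    """Hypothetical callback bound to two hotkeys."""


# Plain dict, as in the question: the second assignment overwrites the first.
plain: dict = {}
plain[my_func] = 'o'
plain[my_func] = 'p'
print(plain[my_func])  # p

# defaultdict(list), as in the answer: every hotkey is kept.
grouped = collections.defaultdict(list)
grouped[my_func].append('o')
grouped[my_func].append('p')
print(grouped[my_func])  # ['o', 'p']
```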

Notes:

• Written explicitly for Python 3.10+, since it uses match. It can be rewritten with if statements, but that would be even messier; it is possible if needed, though.
• This script is not guaranteed to be error-free. I may have overlooked edge cases that I didn't notice while reading the question.
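For Python versions before 3.10, the first match case can be expressed with plain isinstance checks. A rough sketch that handles only the fx.bind('o', callback) form; bind_target is a hypothetical helper name, and the other_module.callback form would need one more branch for an ast.Attribute second argument:

```python
import ast
from typing import Optional, Tuple


def bind_target(node: ast.AST) -> Optional[Tuple[str, str]]:
    """Hypothetical helper: return (hotkey, callback_name) for fx.bind('o', callback), else None."""
    # Like the match pattern, require a call with exactly two positional arguments.
    if not isinstance(node, ast.Call) or len(node.args) != 2:
        return None
    func = node.func
    # The callee must be the attribute access fx.bind.
    if not (isinstance(func, ast.Attribute) and func.attr == 'bind'
            and isinstance(func.value, ast.Name) and func.value.id == 'fx'):
        return None
    hotkey, callback = node.args
    # First argument: a string constant; second argument: a bare name.
    if (isinstance(hotkey, ast.Constant) and isinstance(hotkey.value, str)
            and isinstance(callback, ast.Name)):
        return hotkey.value, callback.id
    return None


tree = ast.parse("fx.bind('o', my_func)\nprint('x')")
print([t for t in map(bind_target, ast.walk(tree)) if t is not None])  # [('o', 'my_func')]
```

Each extra form in the match version becomes another chain of isinstance checks here, which is the added messiness mentioned above.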