python, python-3.x, abstract-syntax-tree

Using Python AST to get the value from keyword of a context manager


I have Python source code that is analyzed by an external tool I am writing. The code uses a context manager, and I'm trying to get the value (a list) of the keyword argument 'tags' passed to the 'DAG' context manager (bound to the name 'dag').

from airflow.models import DAG

with DAG(
    dag_id='my_dag',
    tags=['dbt', 'marketing', 'schema', 'data_vis']
) as dag:
    pass

The desired output is:

['dbt', 'marketing', 'schema', 'data_vis']

I think I can use the ast module to parse the file. How can I achieve this?
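As a starting point, it helps to see where the ast module puts the keyword. A minimal sketch, assuming the source is available as a string (an inline copy of the example above stands in for the file): the 'with' statement is an ast.With node, each 'as' clause is an ast.withitem, and the DAG(...) call's keyword arguments live in its 'keywords' list.

```python
import ast

src = """
from airflow.models import DAG

with DAG(
    dag_id='my_dag',
    tags=['dbt', 'marketing', 'schema', 'data_vis']
) as dag:
    pass
"""

tree = ast.parse(src)
with_node = tree.body[1]                 # body[0] is the ImportFrom, body[1] is the With statement
call = with_node.items[0].context_expr   # the DAG(...) call inside the first withitem
print(type(call).__name__)               # Call
print([kw.arg for kw in call.keywords])  # ['dag_id', 'tags']
print(ast.dump(call.keywords[1].value))  # a List node whose elts are Constant nodes
```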

Solution

The best solution I have found so far is:

# dag.py

from airflow.models import DAG

with DAG(
    dag_id='my_dag',
    tags=['dbt', 'marketing', 'schema', 'data_vis']
) as dag:
    pass
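
# ast_explorer.py

import ast

with open('dag.py', 'r') as f:
    tree = ast.parse(f.read())

tag_list = None  # stays None if no matching DAG(..., tags=...) is found

for node in tree.body:
    if isinstance(node, ast.With):
        for item in node.items:
            call = item.context_expr
            # Match a bare DAG(...) call; the Name check avoids an AttributeError on e.g. models.DAG(...)
            if isinstance(call, ast.Call) and isinstance(call.func, ast.Name) and call.func.id == 'DAG':
                for kw in call.keywords:
                    if kw.arg == 'tags':
                        # Each element is an ast.Constant; .value replaces the deprecated .s alias
                        tag_list = [elt.value for elt in kw.value.elts]

print(tag_list)
# Output: ['dbt', 'marketing', 'schema', 'data_vis']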
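The loop above only looks at top-level statements and only at a bare DAG name. If the DAG might be created inside a function or nested 'with' block, or called as models.DAG(...), an ast.NodeVisitor subclass can handle both. This is a sketch, not the approach from the answer above; the class name DagTagFinder is hypothetical, and it assumes 'tags' is a literal list (ast.literal_eval raises ValueError otherwise).

```python
import ast


class DagTagFinder(ast.NodeVisitor):
    """Collect the literal 'tags' list from every DAG(...) used as a context manager."""

    def __init__(self):
        self.tags = []

    def visit_With(self, node):
        for item in node.items:
            call = item.context_expr
            if isinstance(call, ast.Call) and self._is_dag(call.func):
                for kw in call.keywords:
                    if kw.arg == 'tags':
                        # literal_eval accepts an AST node and rebuilds the Python list
                        self.tags.append(ast.literal_eval(kw.value))
        # Keep descending so nested with blocks are visited too
        self.generic_visit(node)

    @staticmethod
    def _is_dag(func):
        # Matches DAG(...) as well as dotted forms such as models.DAG(...)
        if isinstance(func, ast.Name):
            return func.id == 'DAG'
        if isinstance(func, ast.Attribute):
            return func.attr == 'DAG'
        return False


src = """
from airflow import models

def build():
    with models.DAG(dag_id='my_dag', tags=['dbt', 'marketing']) as dag:
        pass
"""

finder = DagTagFinder()
finder.visit(ast.parse(src))
print(finder.tags)  # [['dbt', 'marketing']]
```

Matching on the attribute name alone means any object.DAG(...) call is accepted; that is a deliberate loosening that may need tightening if other classes named DAG appear in the analyzed code.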

Solution

  • You can walk the whole tree with ast.walk and, whenever you encounter an ast.withitem whose context expression is a call to a name DAG, iterate over that call's keyword nodes and collect the value of every keyword named tags:

    import ast

    src = """
    from airflow.models import DAG

    with DAG(
        dag_id='my_dag',
        tags=['dbt', 'marketing', 'schema', 'data_vis']
    ) as dag:
        pass
    """

    # Every with-item whose context expression is a call to the bare name DAG
    cxt_m = [i for i in ast.walk(ast.parse(src))
             if isinstance(i, ast.withitem)
             and isinstance(i.context_expr, ast.Call)
             and isinstance(i.context_expr.func, ast.Name)
             and i.context_expr.func.id == 'DAG']
    # literal_eval accepts an AST node directly, so no ast.unparse round-trip is needed
    tags = [ast.literal_eval(j.value) for i in cxt_m
            for j in i.context_expr.keywords if j.arg == 'tags']
    print(tags)
    

    Output:

    [['dbt', 'marketing', 'schema', 'data_vis']]
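One caveat with ast.literal_eval: it raises ValueError when tags is not a literal, for example tags=COMMON_TAGS referring to a module-level variable, because that value is an ast.Name node. A hedged variant of the walk above that skips such cases instead of crashing (the function name dag_tags is hypothetical):

```python
import ast

src = """
from airflow.models import DAG

COMMON_TAGS = ['dbt']

with DAG(dag_id='a', tags=['dbt', 'marketing']) as dag_a:
    pass

with DAG(dag_id='b', tags=COMMON_TAGS) as dag_b:
    pass
"""


def dag_tags(source):
    """Return the literal tags lists of DAG(...) context managers, skipping non-literals."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.withitem):
            continue
        call = node.context_expr
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == 'DAG'):
            continue
        for kw in call.keywords:
            if kw.arg != 'tags':
                continue
            try:
                found.append(ast.literal_eval(kw.value))
            except ValueError:
                # e.g. tags=COMMON_TAGS is a Name node, which literal_eval rejects
                pass
    return found


print(dag_tags(src))  # [['dbt', 'marketing']]
```

Resolving a variable like COMMON_TAGS would require tracking assignments, which static AST inspection alone does not do.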