
pyparsing grammar to extract portions of a Python snippet


I have Python snippets that look like the following:

fn = HiveExecOperator(
        task_id="abc",
        hql="abc.hql",
        dq_sql=DQCheck("pqr")
        .is_within_range(
            "COUNT(DISTINCT billing_association_type)",
            "type_cnts",
            lower=1.0,
            upper=float("inf"),
        )
        .build(),
        dag=main_dag,
    )

I would like to define a grammar that lets me look at the key-value pairs in the parameter list of the HiveExecOperator call without breaking down the nested expressions. For example, I am interested in getting back a list like this:

[task_id="abc", 
 hql="abc.hql",
 ...
 dq_sql=DQCheck("pqr")
        .is_within_range(
            "COUNT(DISTINCT billing_association_type)",
            "type_cnts",
            lower=1.0,
            upper=float("inf"),
        )
        .build(),
...]

I tried doing the following:

assignment = variable + '=' + "HiveExecOperator" + nestedExpr('(', ')').setParseAction(lambda x: print(x))

parameters.transformString(python_snippet)

Output from setParseAction is:

['fn', '=', 'HiveExecOperator(']
['task_id', '=', '"abc",']
['hql', '=', '"abc.hql",']
['dq_sql', '=', 'DQCheck("stage.billing_associations")']
['lower', '=', '1.0,']
['upper', '=', 'float("inf"),']
...

Any help will be appreciated.


Solution

  • As mentioned by mkrieger1, you can use Python's built-in ast module instead of writing a grammar.

    Python 3.9 and later provide the ast.unparse function, which turns an AST node (an ast.AST instance) back into source code as a string.

    import ast
    
    mycode = """\
    fn = HiveExecOperator(
            task_id="abc",
            hql="abc.hql",
            dq_sql=DQCheck("pqr")
            .is_within_range(
                "COUNT(DISTINCT billing_association_type)",
                "type_cnts",
                lower=1.0,
                upper=float("inf"),
            )
            .build(),
            dag=main_dag,
        )
    """
    
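    root = ast.parse(mycode)
    # In CPython, ast.walk traverses breadth-first, so the outermost call
    # (HiveExecOperator) comes first; the documented order is unspecified.
    calls = [n for n in ast.walk(root) if isinstance(n, ast.Call)]
    first_call = calls[0]
    # Pair each keyword name with its value rendered back to source text.
    target_list = [(k.arg, ast.unparse(k.value)) for k in first_call.keywords]
    print(target_list)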
    

    which gives (wrapped here for readability)

    [
       ('task_id', "'abc'"),
       ('hql', "'abc.hql'"),
       ('dq_sql', "DQCheck('pqr').is_within_range('COUNT(DISTINCT billing_association_type)', 'type_cnts', lower=1.0, upper=float('inf')).build()"),
       ('dag', 'main_dag')
    ]
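
Note that ast.unparse regenerates the code from the tree, so the original quoting and line breaks are lost (the output above shows 'abc' rather than "abc"). If the original text of each value matters, as in the list the question asks for, one option is ast.get_source_segment (Python 3.8 or later), which slices the source using the positions recorded on each node. The following is a minimal sketch under a few assumptions: the helper name keyword_sources is hypothetical, the call is selected by function name rather than by calls[0], and only plain-name calls such as HiveExecOperator(...) are matched.

```python
import ast

source = '''\
fn = HiveExecOperator(
        task_id="abc",
        hql="abc.hql",
        dq_sql=DQCheck("pqr")
        .is_within_range(
            "COUNT(DISTINCT billing_association_type)",
            "type_cnts",
            lower=1.0,
            upper=float("inf"),
        )
        .build(),
        dag=main_dag,
    )
'''


def keyword_sources(code, func_name):
    """Return (name, original source text) pairs for the keyword
    arguments of every call to `func_name` found in `code`.

    Hypothetical helper for illustration only.
    """
    pairs = []
    for node in ast.walk(ast.parse(code)):
        # Match plain-name calls only; attribute calls such as
        # module.HiveExecOperator(...) are not handled in this sketch.
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == func_name
        ):
            for kw in node.keywords:
                # kw.arg is None for a **kwargs entry.
                text = ast.get_source_segment(code, kw.value)
                pairs.append((kw.arg, text))
    return pairs


pairs = keyword_sources(source, "HiveExecOperator")
for name, text in pairs:
    print(f"{name}={text}")
```

Each value comes back exactly as written, so the dq_sql entry keeps its double quotes and spans several lines, including the nested is_within_range(...) call. Nested keywords such as lower and upper are not reported separately, because only the keywords of the matched call are read.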