Tags: python, abstract-syntax-tree, libcst

Find if-nodes that are immediately followed by a raise-node in Python with libcst


Right now I am working on a project for a university course. I was given some random functions, and most of them contain an if-raise statement somewhere in the code.

I am trying to find those, but only those one or two lines. I parse the functions with libcst (which builds a concrete syntax tree rather than an AST) and then visit the tree: I extend the visitor class, look for if-nodes and then match for raise-nodes. However, this also matches and saves constructs like if-if-raise (a raise inside a nested if) or if-else-raise (a raise in the else branch).

I hope someone can help me modify the matcher so that it only matches if-nodes that are directly followed by a single raise node. (Sequence wildcard matchers would be great, but as far as I understand, they cannot be used to find sequences of nodes.)

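    import libcst as cst
    import libcst.matchers as m

    class FindIfRaise(cst.CSTVisitor):

        def __init__(self):
            super().__init__()
            self.if_raise = []  # collected If nodes

        def visit_If(self, node: cst.If):
            # m.findall searches the whole subtree below `node`, so a raise in a
            # nested if or in an else branch also counts as a match
            if m.findall(node, m.Raise()):
                self.if_raise.append(node)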

Thanks in advance for any help.


Solution

  • Instead of the visitor pattern, you can recursively traverse the CST yourself, navigating via the `body` attribute of each node. This way you can track the chain of parent nodes, check for sibling if statements, and only yield raise statements when the desired conditions are met:

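    import libcst as cst

    # Wrapper nodes that sit between an `if` and the statements inside it.
    WRAPPERS = (cst.IndentedBlock, cst.SimpleStatementLine)

    def walk(node, parents=()):
        # Follow `.body` attributes down to a sequence of statements. Checking
        # for CSTNode (instead of `list`) also works if bodies are tuples.
        body = node
        while isinstance(body := getattr(body, "body", ()), cst.CSTNode):
            pass
        for idx, stmt in enumerate(body):
            if isinstance(stmt, cst.Raise):
                # Nearest ancestor that is a real statement, not a wrapper.
                owner = next((p for p in reversed(parents) if not isinstance(p, WRAPPERS)), None)
                if isinstance(owner, cst.If):
                    yield stmt
            elif isinstance(stmt, cst.If):
                # Only descend into an `if` that has no elif/else branch and is
                # not immediately followed by another `if` statement.
                followed_by_if = idx + 1 < len(body) and isinstance(body[idx + 1], cst.If)
                if stmt.orelse is None and not followed_by_if:
                    yield from walk(stmt, parents + (stmt,))
            else:
                yield from walk(stmt, parents + (stmt,))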
    
    s = """
    if something:
       raise Exception
    if something_else:
       pass
    """
    print([*walk(cst.parse_module(s))]) #[], since `if something` is followed by another if-statement
    s1 = """
    if something:
       raise Exception
    elif something_else:
       pass
    """
    print([*walk(cst.parse_module(s1))]) #[], since `if something` is followed by an elif-statement
    s2 = """
    if something:
       raise Exception
    for i in range(10): pass
    """
    print([*walk(cst.parse_module(s2))]) #[Raise(
    #    exc=Name(
    #        value='Exception',
    #        lpar=[],
    #        rpar=[],
    #    ),
    #    cause=None,
    #    whitespace_after_raise=SimpleWhitespace(
    #        value=' ',
    #    ),
    #    semicolon=MaybeSentinel.DEFAULT,
    #)]