
Parse CASE WHEN statements with sqlparse


I have the following SQL query and would like to parse it using sqlparse:

import sqlparse

query =  """
select SUM(case when(A.dt_unix<=86400
                     and B.flag="V") then 1
           end) as TEST_COLUMN_1,
       SUM(case when(A.Amt - B.Amt > 0
                     and B.Cat1 = "A"
                     and (B.Cat2 = "M"
                          or B.Cat3 = "C"
                          or B.Cat4 = "B")
                     and B.Cat5 is NULL) then 1
           end) as TEST_COLUMN_2
from test_table A
left join test_table_2 as B on A.ID=B.ID
where A.DT >B.DT
group by A.ID
"""

query_tokens = sqlparse.parse(query)[0].tokens
print(query_tokens)

This prints the top-level tokens of the SQL statement:

[<Newline ' ' at 0x7FAA62BD9F48>, <DML 'select' at 0x7FAA62BE7288>, <Whitespace ' ' at 0x7FAA62BE72E8>, <IdentifierList 'SUM(ca...' at 0x7FAA62BF7CF0>, <Newline ' ' at 0x7FAA62BF6288>, <Keyword 'from' at 0x7FAA62BF62E8>, <Whitespace ' ' at 0x7FAA62BF6348>, <Identifier 'test_t...' at 0x7FAA62BF7570>, <Newline ' ' at 0x7FAA62BF64C8>, <Keyword 'left j...' at 0x7FAA62BF6528>, <Whitespace ' ' at 0x7FAA62BF6588>, <Identifier 'test_t...' at 0x7FAA62BF7660>, <Whitespace ' ' at 0x7FAA62BF67C8>, <Keyword 'on' at 0x7FAA62BF6828>, <Whitespace ' ' at 0x7FAA62BF6888>, <Comparison 'A.ID=B...' at 0x7FAA62BF7B10>, <Newline ' ' at 0x7FAA62BF6B88>, <Where 'where ...' at 0x7FAA62BF28B8>, <Keyword 'group' at 0x7FAA62BD9E88>, <Whitespace ' ' at 0x7FAA62BD93A8>, <Keyword 'by' at 0x7FAA62BD9EE8>, <Whitespace ' ' at 0x7FAA62C1CEE8>, <Identifier 'A.ID' at 0x7FAA62BF2F48>, <Newline ' ' at 0x7FAA62BF6C48>]

How can I parse these tokens to process the CASE WHEN statements so that I can extract all the conditions and maintain their precedence as defined by the parentheses? I was not able to find any relevant examples in the documentation.

Any thoughts on this?


Solution

  • The project is indeed a little underdocumented. I looked at the examples and scanned the source code a little. The documentation unfortunately doesn't include all methods on the Token and TokenList classes that are useful for this task.

    For example, an important but undocumented method is TokenList.get_sublists(), which lets you traverse nested token lists more easily than the other methods do. The TokenList.flatten() method only yields the ungrouped (leaf) tokens in the tree, whereas CASE is a grouped token, so going purely by the documentation you might find it hard to do something useful with the parsed token tree.

    Another handy method that I noticed in the codebase is the TokenList._pprint_tree() method, which dumps out the current token tree to stdout. This is very helpful when trying to write code that analyses the tree.

    My overall impression is that sqlparse is less of a parsing library than a tool to re-format SQL. It includes a good parser but doesn't include the tools necessary to make general use of the tree it produces.

    What is really missing in the library is a base node visitor class such as that provided by the Python ast module, or a tree node walker, again like the ast module provides. Either is easy enough to build yourself, luckily:

    from collections import deque
    from sqlparse.sql import TokenList
    
    class SQLTokenVisitor:
        def visit(self, token):
            """Visit a token."""
            method = 'visit_' + type(token).__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(token)
    
        def generic_visit(self, token):
            """Called if no explicit visitor function exists for a node."""
            if not isinstance(token, TokenList):
                return
            for tok in token:
                self.visit(tok)
    
    def walk_tokens(token):
        """Yield every token in the tree, breadth-first."""
        queue = deque([token])
        while queue:
            token = queue.popleft()
            if isinstance(token, TokenList):
                queue.extend(token)
            yield token
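
For readers who have not used the ast patterns that these two helpers mirror, here is a small standard-library sketch for comparison. The NameCollector class and the sample expression are made up for illustration; ast.NodeVisitor and ast.walk are the real ast APIs.

```python
import ast

class NameCollector(ast.NodeVisitor):
    """Collect identifier names via visit_<NodeType> dispatch."""
    def __init__(self):
        self.names = []

    def visit_Name(self, node):
        self.names.append(node.id)
        self.generic_visit(node)  # keep descending into child nodes

tree = ast.parse("total = price * quantity")

# Visitor pattern: dispatch on the node's class name
collector = NameCollector()
collector.visit(tree)
visited = collector.names

# Walker pattern: iterate over every node and filter by type
walked = [node.id for node in ast.walk(tree) if isinstance(node, ast.Name)]
```

Both approaches find the same three names; the visitor keeps state on the instance, while the walker leaves filtering to the caller.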
    

    Now you can use either to access the Case nodes:

    statement, = sqlparse.parse(query)
    
    class CaseVisitor(SQLTokenVisitor):
        """Build a list of SQL Case nodes
    
          The .cases list is a list of (condition, value) tuples per CASE statement
    
        """
        def __init__(self):
            self.cases = []
    
        def visit_Case(self, token):
            # get_cases() returns a list of (condition, value) pairs,
            # one per WHEN / ELSE branch of this CASE statement
            self.cases.append(token.get_cases())
    
    visitor = CaseVisitor()
    visitor.visit(statement)
    cases = visitor.cases
    

    or

    statement, = sqlparse.parse(query)
    
    cases = []
    for token in walk_tokens(statement):
        if isinstance(token, sqlparse.sql.Case):
            cases.append(token.get_cases())
    

    The difference between the walk_tokens() and SQLTokenVisitor patterns is negligible in this example, because we only collect the separated tokens for each CASE statement without processing the WHEN ... THEN ... tokens any further. With the visitor pattern you'd set more attributes on the visitor instance to 'switch gears' and capture further information about those subtree tokens in more visit_... methods, which may be easier to follow than a nested for loop over a generator.

    On the other hand, with the walk_tokens() generator, if you create a separate variable to reference the generator, you can hand over iteration to helper functions:

    all_tokens = walk_tokens(statement)
    for token in all_tokens:
        if isinstance(token, sqlparse.sql.Case):
            # the helper continues consuming the same generator
            branches = extract_branches(all_tokens)
    

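    where extract_branches would keep iterating until it reaches the end of the CASE statement. Note that walk_tokens() as written above is breadth-first, so the tokens that follow a Case node in the iteration are its siblings rather than its children; for this hand-off to work, the walker needs to yield tokens depth-first, in document order.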